1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
17
18#include <stddef.h>
19#include <stdlib.h>
20
21#include <algorithm>
22#include <cstdint>
23#include <memory>
24#include <string>
25#include <utility>
26#include <vector>
27
28#include "absl/base/attributes.h"
29#include "absl/container/flat_hash_map.h"
30#include "absl/container/flat_hash_set.h"
31#include "absl/strings/match.h"
32#include "absl/strings/str_cat.h"
33#include "absl/strings/str_format.h"
34#include "absl/strings/str_join.h"
35#include "absl/strings/string_view.h"
36#include "flatbuffers/flatbuffers.h" // from @flatbuffers
37#include "flatbuffers/flexbuffers.h" // from @flatbuffers
38#include "llvm/ADT/ArrayRef.h"
39#include "llvm/ADT/DenseMap.h"
40#include "llvm/ADT/None.h"
41#include "llvm/ADT/Optional.h"
42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/StringRef.h"
44#include "llvm/Support/Casting.h"
45#include "llvm/Support/CommandLine.h"
46#include "llvm/Support/FormatVariadic.h"
47#include "llvm/Support/ToolOutputFile.h"
48#include "llvm/Support/raw_ostream.h"
49#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
50#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
51#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
52#include "mlir/IR/Attributes.h" // from @llvm-project
53#include "mlir/IR/Builders.h" // from @llvm-project
54#include "mlir/IR/BuiltinOps.h" // from @llvm-project
55#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
56#include "mlir/IR/Location.h" // from @llvm-project
57#include "mlir/IR/MLIRContext.h" // from @llvm-project
58#include "mlir/IR/Operation.h" // from @llvm-project
59#include "mlir/IR/Types.h" // from @llvm-project
60#include "mlir/IR/Value.h" // from @llvm-project
61#include "mlir/Support/LogicalResult.h" // from @llvm-project
62#include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project
63#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h"
64#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
65#include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
66#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
67#include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h"
68#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
69#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
70#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
71#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
72#include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h"
73#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
74#include "tensorflow/compiler/xla/statusor.h"
75#include "tensorflow/core/framework/attr_value.pb.h"
76#include "tensorflow/core/framework/node_def.pb.h"
77#include "tensorflow/core/platform/errors.h"
78#include "tensorflow/core/platform/logging.h"
79#include "tensorflow/core/platform/status.h"
80#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
81#include "tensorflow/lite/experimental/remat/metadata_util.h"
82#include "tensorflow/lite/kernels/internal/kernel_utils.h"
83#include "tensorflow/lite/schema/schema_conversion_utils.h"
84#include "tensorflow/lite/schema/schema_generated.h"
85#include "tensorflow/lite/string_util.h"
86#include "tensorflow/lite/tools/versioning/gpu_compatibility.h"
87#include "tensorflow/lite/tools/versioning/op_version.h"
88#include "tensorflow/lite/tools/versioning/runtime_version.h"
89#include "tensorflow/lite/version.h"
90
91using llvm::dyn_cast;
92using llvm::formatv;
93using llvm::isa;
94using llvm::Optional;
95using llvm::StringRef;
96using llvm::Twine;
97using mlir::Dialect;
98using mlir::ElementsAttr;
99using mlir::MLIRContext;
100using mlir::ModuleOp;
101using mlir::NoneType;
102using mlir::Operation;
103using mlir::Region;
104using mlir::StringAttr;
105using mlir::TensorType;
106using mlir::Type;
107using mlir::UnknownLoc;
108using mlir::Value;
109using mlir::WalkResult;
110using mlir::func::FuncOp;
111using tensorflow::OpOrArgLocNameMapper;
112using tensorflow::OpOrArgNameMapper;
113using tensorflow::Status;
114using tflite::flex::IsAllowlistedFlexOp;
115using xla::StatusOr;
116
117template <typename T>
118using BufferOffset = flatbuffers::Offset<T>;
119
120template <typename T>
121using VectorBufferOffset = flatbuffers::Offset<flatbuffers::Vector<T>>;
122
123using CustomOptionsOffset = VectorBufferOffset<uint8_t>;
124
125namespace error = tensorflow::error;
126namespace tfl = mlir::TFL;
127
128ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
129
// Use initial buffer size in flatbuffer builder to be same as the initial size
// used by the TOCO export. (The rationale for this particular value is not
// documented in the TOCO code.)
132constexpr size_t kInitialBufferSize = 10240;
133
// Returns the TFLite tensor type corresponding to the MLIR `type`.
// Set `is_signed` to false if the `type` is an 8-bit unsigned integer type.
// Since tflite doesn't support unsigned for other types, returns error if
// `is_signed` is set to false for other types.
static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
                                                  bool is_signed = true) {
  // Signless 8-bit integers are re-interpreted as UINT8 when the caller asks
  // for an unsigned type (e.g. for quantized storage types below).
  if (!is_signed && type.isSignlessInteger(8)) {
    return tflite::TensorType_UINT8;
  }
  if (!is_signed) {
    // NOTE(review): the message uses the historical spelling 'isSigned'; the
    // parameter is `is_signed`.
    return Status(error::INVALID_ARGUMENT,
                  "'isSigned' can only be set for 8-bits integer type");
  }

  if (type.isF32()) {
    return tflite::TensorType_FLOAT32;
  } else if (type.isF16()) {
    return tflite::TensorType_FLOAT16;
  } else if (type.isF64()) {
    return tflite::TensorType_FLOAT64;
  } else if (type.isa<mlir::TF::StringType>()) {
    return tflite::TensorType_STRING;
  } else if (type.isa<mlir::TF::Quint8Type>()) {
    return tflite::TensorType_UINT8;
  } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) {
    // Only complex numbers with f32/f64 element type have TFLite equivalents.
    auto ftype = complex_type.getElementType();
    if (ftype.isF32()) {
      return tflite::TensorType_COMPLEX64;
    }
    if (ftype.isF64()) {
      return tflite::TensorType_COMPLEX128;
    }
    return Status(error::INVALID_ARGUMENT, "Unsupported type");
  } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) {
    switch (itype.getWidth()) {
      case 1:
        return tflite::TensorType_BOOL;
      case 8:
        return itype.isUnsigned() ? tflite::TensorType_UINT8
                                  : tflite::TensorType_INT8;
      case 16:
        return itype.isUnsigned() ? tflite::TensorType_UINT16
                                  : tflite::TensorType_INT16;
      case 32:
        return itype.isUnsigned() ? tflite::TensorType_UINT32
                                  : tflite::TensorType_INT32;
      case 64:
        return itype.isUnsigned() ? tflite::TensorType_UINT64
                                  : tflite::TensorType_INT64;
    }
  } else if (auto q_uniform_type =
                 type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
    // Quantized types serialize as their storage type; signedness comes from
    // the quantized type itself.
    return GetTFLiteType(q_uniform_type.getStorageType(),
                         q_uniform_type.isSigned());
  } else if (auto q_peraxis_type =
                 type.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
    return GetTFLiteType(q_peraxis_type.getStorageType(),
                         q_peraxis_type.isSigned());
  } else if (auto q_calibrated_type =
                 type.dyn_cast<mlir::quant::CalibratedQuantizedType>()) {
    // Calibrated types serialize as their expressed (float) type.
    return GetTFLiteType(q_calibrated_type.getExpressedType());
  } else if (type.isa<mlir::TF::ResourceType>()) {
    return tflite::TensorType_RESOURCE;
  } else if (type.isa<mlir::TF::VariantType>()) {
    return tflite::TensorType_VARIANT;
  }
  // TFLite export fills FLOAT32 for unknown data types. Returning an error
  // for now for safety and this could be revisited when required.
  return Status(error::INVALID_ARGUMENT, "Unsupported type");
}
203
// Returns true if `op` is one of the constant(-like) operations that the
// exporter serializes as flatbuffer buffers rather than as operators.
static bool IsConst(Operation* op) {
  return isa<mlir::func::ConstantOp, mlir::arith::ConstantOp, mlir::TF::ConstOp,
             tfl::ConstOp, tfl::QConstOp, tfl::SparseConstOp,
             tfl::SparseQConstOp, mlir::TFL::NoValueOp>(op);
}
209
210static bool IsTFResourceOp(Operation* op) {
211 for (const auto& operand : op->getOperands()) {
212 auto elementType = getElementTypeOrSelf(operand.getType());
213 if (elementType.isa<mlir::TF::ResourceType>()) {
214 return true;
215 }
216 }
217 for (const auto& result : op->getResults()) {
218 auto elementType = getElementTypeOrSelf(result.getType());
219 if (elementType.isa<mlir::TF::ResourceType>()) {
220 return true;
221 }
222 }
223 return false;
224}
225
// Returns whether the current op is not supported by the TF Lite runtime,
// even when exported as a Flex (select TF) op.
static bool IsUnsupportedFlexOp(const std::string& op_name) {
  for (const char* unsupported : {"PartitionedCall", "StatefulPartitionedCall"}) {
    if (op_name == unsupported) return true;
  }
  return false;
}
230
231// Create description of operation that could not be converted.
232static std::string GetOpDescriptionForDebug(Operation* inst) {
233 const int kLargeElementsAttr = 16;
234 std::string op_str;
235 llvm::raw_string_ostream os(op_str);
236 inst->getName().print(os);
237 os << "(";
238 if (!inst->getOperandTypes().empty()) {
239 bool first = true;
240 for (Type operand_type : inst->getOperandTypes()) {
241 os << (!first ? ", " : "");
242 first = false;
243 os << operand_type;
244 }
245 }
246 os << ") -> (";
247 if (!inst->getResultTypes().empty()) {
248 bool first = true;
249 for (Type result_type : inst->getResultTypes()) {
250 os << (!first ? ", " : "");
251 first = false;
252 os << result_type;
253 }
254 }
255 os << ")";
256 // Print out attributes except for large elementsattributes (which should
257 // rarely be the cause why the legalization didn't happen).
258 if (!inst->getAttrDictionary().empty()) {
259 os << " : {";
260 bool first = true;
261 for (auto& named_attr : inst->getAttrDictionary()) {
262 os << (!first ? ", " : "");
263 first = false;
264 os << named_attr.getName().getValue() << " = ";
265 if (auto element_attr = named_attr.getValue().dyn_cast<ElementsAttr>()) {
266 if (element_attr.getNumElements() <= kLargeElementsAttr) {
267 element_attr.print(os);
268 } else {
269 os << "<large>";
270 }
271 } else {
272 named_attr.getValue().print(os);
273 }
274 }
275 os << "}";
276 }
277 return os.str();
278}
279
// Create a summary with the given information regarding op names and
// descriptions.
static std::string GetOpsSummary(
    const std::map<std::string, std::set<std::string>>& ops,
    const std::string& summary_title) {
  std::string name_list;
  std::string detail_list;
  // Both the map and the inner sets are ordered, so the output is
  // deterministic.
  for (const auto& op_entry : ops) {
    if (!name_list.empty()) name_list += ", ";
    name_list += op_entry.first;
    for (const auto& detail : op_entry.second) {
      if (!detail_list.empty()) detail_list += "\n\t";
      detail_list += detail;
    }
  }
  return summary_title + " ops: " + name_list + "\nDetails:\n\t" + detail_list;
}
306
307template <typename T>
308static bool HasValidTFLiteType(Value value, T& error_handler) {
309 // None type is allowed to represent unspecified operands.
310 if (value.getType().isa<NoneType>()) return true;
311
312 auto type = value.getType().dyn_cast<TensorType>();
313 if (!type) {
314 if (auto op = value.getDefiningOp()) {
315 error_handler.emitError()
316 << '\'' << op << "' should produce value of tensor type instead of "
317 << value.getType();
318 return false;
319 }
320 error_handler.emitError("expected tensor type, got ") << value.getType();
321 return false;
322 }
323
324 Type element_type = type.getElementType();
325 auto status = GetTFLiteType(element_type);
326 if (!status.ok()) {
327 return error_handler.emitError(
328 formatv("Failed to convert element type '{0}': {1}",
329 element_type, status.status().error_message())),
330 false;
331 }
332 return true;
333}
334
335// Returns true if the module holds all the invariants expected by the
336// Translator class.
337// TODO(hinsu): Now that translation is done by making a single pass over the
338// MLIR module, consider inlining these validation checks at the place where
339// these invariants are assumed instead of checking upfront.
340static bool IsValidTFLiteMlirModule(ModuleOp module) {
341 MLIRContext* context = module.getContext();
342
343 // Verify that module has a function named main.
344 FuncOp main_fn = module.lookupSymbol<FuncOp>("main");
345 if (!main_fn) {
346 int entry_func_count = 0;
347 for (auto fn : module.getOps<FuncOp>()) {
348 auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
349 if (attrs && !attrs.empty()) {
350 ++entry_func_count;
351 }
352 }
353
354 // Verify that module has a least one enrty function.
355 if (entry_func_count == 0) {
356 return emitError(UnknownLoc::get(context),
357 "should have a least one entry function"),
358 false;
359 }
360 }
361
362 for (auto fn : module.getOps<FuncOp>()) {
363 if (!llvm::hasSingleElement(fn)) {
364 return fn.emitError("should have exactly one basic block"), false;
365 }
366 auto& bb = fn.front();
367
368 for (auto arg : bb.getArguments()) {
369 if (!HasValidTFLiteType(arg, fn)) {
370 auto elementType = getElementTypeOrSelf(arg.getType());
371 if (elementType.isa<mlir::TF::VariantType>()) {
372 return fn.emitError(
373 "function argument uses variant type. Currently, the "
374 "variant type is not natively supported in TFLite. Please "
375 "consider not using the variant type: ")
376 << arg.getType(),
377 false;
378 }
379 return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
380 }
381 }
382
383 // Verify that all operations except the terminator have exactly one
384 // result of type supported by TFLite (or is a ControlType, which
385 // will be removed later by ExtractControlEdges.)
386 for (auto& inst : bb) {
387 if (inst.hasTrait<mlir::OpTrait::IsTerminator>()) break;
388
389 for (auto result : inst.getResults()) {
390 if (result.getType().isa<mlir::TFL::ControlType>()) continue;
391 if (!HasValidTFLiteType(result, inst)) {
392 auto elementType = getElementTypeOrSelf(result.getType());
393 if (elementType.isa<mlir::TF::VariantType>()) {
394 return inst.emitError(
395 "operand result uses variant type. Currently, the "
396 "variant type is not natively supported in TFLite. "
397 "Please "
398 "consider not using the variant type: ")
399 << result.getType(),
400 false;
401 }
402 return fn.emitError("invalid TFLite type: ") << result.getType(),
403 false;
404 }
405 }
406 }
407 }
408
409 return true;
410}
411
412static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef(
413 ::mlir::Operation* inst) {
414 // We pass empty string for the original node_def name since Flex runtime
415 // does not care about this being set correctly on node_def. There is no
416 // "easy" (see b/120948529) way yet to get this from MLIR inst.
417 auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef(
418 inst, /*name=*/"", /*ignore_unregistered_attrs=*/true);
419 if (!status_or_node_def.ok()) {
420 inst->emitOpError(
421 Twine("failed to obtain TensorFlow nodedef with status: " +
422 status_or_node_def.status().ToString()));
423 return {};
424 }
425 return std::move(status_or_node_def.value());
426}
427
428// Converts a mlir padding StringRef to TfLitePadding.
429// Returns llvm::None if conversion fails.
430static Optional<TfLitePadding> GetTflitePadding(Operation* inst,
431 llvm::StringRef padding) {
432 const tflite::Padding padding_attr =
433 std::move(llvm::StringSwitch<tflite::Padding>(padding)
434 .Case("SAME", tflite::Padding_SAME)
435 .Case("VALID", tflite::Padding_VALID));
436 if (padding_attr == tflite::Padding_SAME) {
437 return kTfLitePaddingSame;
438 }
439 if (padding_attr == tflite::Padding_VALID) {
440 return kTfLitePaddingValid;
441 }
442
443 return inst->emitOpError() << "Invalid padding attribute: " << padding,
444 llvm::None;
445}
446
447// Extracts TfLitePoolParams from a TFL custom op.
448// Template parameter, TFLOp, should be a TFL custom op containing attributes
449// generated from TfLitePoolParams.
450// Returns llvm::None if conversion fails.
451template <typename TFLOp>
452static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst,
453 TFLOp op) {
454 TfLitePoolParams pool_params;
455 pool_params.stride_height = op.stride_h().getSExtValue();
456 pool_params.stride_width = op.stride_w().getSExtValue();
457 pool_params.filter_height = op.filter_h().getSExtValue();
458 pool_params.filter_width = op.filter_w().getSExtValue();
459 const auto padding = GetTflitePadding(inst, op.padding());
460 if (padding) {
461 pool_params.padding = *padding;
462 pool_params.activation = kTfLiteActNone;
463 pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
464 return pool_params;
465 }
466
467 return llvm::None;
468}
469
470namespace {
471
// Helper struct that wraps inputs/outputs of a single SignatureDef.
struct SignatureDefData {
  // Note, we are using ordered maps here to make the iteration order
  // deterministic, which keeps testing simple.

  // Inputs defined in the signature def mapped to tensor names.
  std::map<std::string, std::string> inputs;
  // Outputs defined in the signature def mapped to tensor names.
  std::map<std::string, std::string> outputs;
  // Signature key.
  std::string signature_key;
  // Index of the subgraph (in the flatbuffer) implementing this signature.
  uint32_t subgraph_index;
};
486
// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
class Translator {
 public:
  // Translates the given MLIR module into TFLite FlatBuffer format and returns
  // the serialized output. Returns llvm::None on unsupported, invalid inputs or
  // internal error.
  static Optional<std::string> Translate(
      ModuleOp module, const toco::TocoFlags& toco_flags,
      const std::unordered_set<std::string>& tags,
      OpOrArgNameMapper* op_or_arg_name_mapper,
      const std::map<std::string, std::string>& metadata);

 private:
  // Kinds of op conversions that may be enabled, derived from toco flags in
  // the constructor.
  enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };
  explicit Translator(ModuleOp module, const toco::TocoFlags& toco_flags,
                      const std::unordered_set<std::string>& saved_model_tags,
                      OpOrArgNameMapper* op_or_arg_name_mapper,
                      const std::map<std::string, std::string>& metadata)
      : module_(module),
        name_mapper_(*op_or_arg_name_mapper),
        builder_(kInitialBufferSize),
        saved_model_tags_(saved_model_tags),
        allow_all_select_tf_ops_(toco_flags.allow_all_select_tf_ops()),
        select_user_tf_ops_(toco_flags.select_user_tf_ops().begin(),
                            toco_flags.select_user_tf_ops().end()),
        metadata_(metadata),
        supported_backends_(toco_flags.supported_backends().begin(),
                            toco_flags.supported_backends().end()) {
    // The first buffer must be empty according to the schema definition.
    empty_buffer_ = tflite::CreateBuffer(builder_);
    buffers_.push_back(empty_buffer_);
    if (!toco_flags.force_select_tf_ops()) {
      enabled_op_types_.emplace(OpType::kTfliteBuiltin);
    }
    if (toco_flags.enable_select_tf_ops()) {
      enabled_op_types_.emplace(OpType::kSelectTf);
    }
    if (toco_flags.allow_custom_ops()) {
      enabled_op_types_.emplace(OpType::kCustomOp);
    }
    tf_dialect_ =
        module.getContext()->getOrLoadDialect<mlir::TF::TensorFlowDialect>();
    tfl_dialect_ = module.getContext()
                       ->getOrLoadDialect<mlir::TFL::TensorFlowLiteDialect>();
    // Right now the TF executor dialect is still needed to build NodeDef.
    module.getContext()
        ->getOrLoadDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
  }

  // Performs the translation of the module held by this instance.
  Optional<std::string> TranslateInternal();

  // Returns TFLite buffer populated with constant value if the operation is
  // TFLite constant operation. Otherwise, returns an empty buffer. Emits error
  // and returns llvm::None on failure.
  Optional<BufferOffset<tflite::Buffer>> BuildBuffer(Operation* inst);

  // Build TFLite tensor from the given type. This function is for tfl.lstm
  // intermediates, which should have UniformQuantizedType.
  Optional<BufferOffset<tflite::Tensor>> BuildTensorFromType(
      mlir::Type type, const std::string& name);

  // Builds TF::VariantType from the given element type. Returns llvm::None if
  // failure. Returns empty vector if the element type is not TF::VariantType or
  // there is empty TensorType in the TF::VariantType.
  Optional<std::vector<BufferOffset<tflite::VariantSubType>>>
  BuildTFVariantType(mlir::Type element_type);

  // Builds TFLite tensor from the given value. `buffer_idx` is index of the
  // corresponding buffer. Emits error and returns llvm::None on failure.
  Optional<BufferOffset<tflite::Tensor>> BuildTensor(
      Value value, const std::string& name, unsigned buffer_idx,
      const Optional<BufferOffset<tflite::QuantizationParameters>>&
          quant_parameters);

  // TODO(b/137395003): Legalize tf.IfOp to TFLite dialect, and change the
  // following method to handle TFL::IfOp.
  BufferOffset<tflite::Operator> BuildIfOperator(
      mlir::TF::IfOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Build while operator where cond & body are regions.
  Optional<BufferOffset<tflite::Operator>> BuildWhileOperator(
      mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Build call once operator.
  BufferOffset<tflite::Operator> BuildCallOnceOperator(
      mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Builds the NumericVerify custom operator used for QDQ debugging.
  BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
      mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Builds an operator for a TFL custom op, serializing its custom options.
  BufferOffset<tflite::Operator> BuildCustomOperator(
      Operation* inst, mlir::TFL::CustomOp op,
      const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Serializes `node_def` into Flex custom options. Returns llvm::None on
  // failure.
  Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
      const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);

  // Serializes `node_def` attributes into custom-op custom options. Returns
  // llvm::None on failure.
  Optional<CustomOptionsOffset> CreateCustomOpCustomOptions(
      const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);

  // Creates a flexbuffer builder populated with the attributes of `node_def`.
  std::unique_ptr<flexbuffers::Builder> CreateFlexBuilderWithNodeAttrs(
      const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);

  // Returns opcode index for op identified by the op_name, if already
  // available. Otherwise, creates a new OperatorCode using the given `builtin`
  // operator and associates it with `op_name`.
  uint32_t GetOpcodeIndex(const std::string& op_name,
                          tflite::BuiltinOperator builtin);

  // Builds operator for the given operation with specified operand and result
  // tensor indices. Emits an error and returns llvm::None on failure.
  Optional<BufferOffset<tflite::Operator>> BuildOperator(
      Operation* inst, std::vector<int32_t> operands,
      const std::vector<int32_t>& results,
      const std::vector<int32_t>& intermediates);

  // Returns the quantization parameters for output value of "quant.stats" op.
  BufferOffset<tflite::QuantizationParameters>
  GetQuantizationForQuantStatsOpOutput(mlir::quantfork::StatisticsOp stats_op);

  // Build a subgraph with a given name out of the region either corresponding
  // to a function's body or while op. Modifies *region by calling
  // ExtractControlEdges.
  Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
      const std::string& name, Region* region, const int index);

  // Modifies *block by unwrapping all ControlNodeOps. The DAG of the control
  // dependencies is returned as a vector of its edges, with node indices into
  // *block.
  std::vector<std::pair<int, int>> ExtractControlEdges(mlir::Block* block);

  // Builds Metadata with the given `name` and buffer `content`.
  BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
                                               StringRef content);

  // Encodes the `tfl.metadata` dictionary attribute of the module to the
  // metadata section in the final model.
  Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
  CreateMetadataVector();

  // Builds and returns list of tfl.SignatureDef sections in the model.
  Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
  CreateSignatureDefs(const std::vector<SignatureDefData>& signature_defs);

  // Returns list of offsets for the passed 'items' in TensorMap structure
  // inside the flatbuffer.
  // 'items' is a map from tensor name in signatureDef to tensor name in
  // the subgraph, specified by the 'subgraph_index' argument.
  std::vector<BufferOffset<tflite::TensorMap>> GetList(
      const int subgraph_index,
      const std::map<std::string, std::string>& items);

  // Uses the tf.entry_function attribute (if set) to initialize the op to name
  // mapping.
  void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr);

  // Determines if the specified operation op's operand at operand_index
  // is marked as a stateful operand.
  bool IsStatefulOperand(mlir::Operation* op, int operand_index);

  // Returns a unique name for `val`.
  std::string UniqueName(mlir::Value val);

  // Serializes sparsity parameters of a sparse constant tensor.
  BufferOffset<tflite::SparsityParameters> BuildSparsityParameters(
      const mlir::TFL::SparsityParameterAttr& s_attr);

  // Sums MAC counts over all ops implementing the arithmetic-count interface.
  // Returns false if any op reports an undeterminable count.
  bool EstimateArithmeticCount(int64_t* count);

  // Check compatibility with GPU delegate and returns the compatibility.
  bool CheckGpuDelegateCompatibility(uint8_t* model_buffer_pointer);

  ModuleOp module_;

  tensorflow::OpOrArgNameMapper& name_mapper_;

  flatbuffers::FlatBufferBuilder builder_;
  // The mandatory empty buffer at index 0 (see schema); shared by all
  // non-constant tensors.
  BufferOffset<tflite::Buffer> empty_buffer_;

  std::vector<BufferOffset<tflite::Buffer>> buffers_;
  // Maps subgraph index and tensor name in the graph to the tensor index.
  absl::flat_hash_map<int, absl::flat_hash_map<std::string, int>>
      tensor_index_map_;

  // Maps op name to index of the corresponding OperatorCode in opcodes_ vector.
  absl::flat_hash_map<std::string, uint32_t> opcode_index_map_;
  std::vector<BufferOffset<tflite::OperatorCode>> opcodes_;

  // Maps function name to index of the corresponding subgraph in the FlatBuffer
  // model.
  absl::flat_hash_map<std::string, int> subgraph_index_map_;
  absl::flat_hash_set<OpType> enabled_op_types_;

  // Points to TensorFlow and TFLite dialects, respectively. nullptr if the
  // dialect is not registered.
  const Dialect* tf_dialect_;
  const Dialect* tfl_dialect_;

  // The failed ops during legalization.
  std::map<std::string, std::set<std::string>> failed_flex_ops_;
  std::map<std::string, std::set<std::string>> failed_custom_ops_;

  // Ops to provide warning messages.
  std::map<std::string, std::set<std::string>> custom_ops_;
  std::map<std::string, std::set<std::string>> flex_ops_;

  // Resource ops to provide warning messages.
  std::map<std::string, std::set<std::string>> resource_ops_;

  // Set of saved model tags, if any.
  const std::unordered_set<std::string> saved_model_tags_;
  // Allows automatic pass through of TF ops as select Tensorflow ops.
  const bool allow_all_select_tf_ops_;
  // User's defined ops allowed with Flex.
  const std::unordered_set<std::string> select_user_tf_ops_;
  // Map of key value pairs of metadata to export.
  const std::map<std::string, std::string> metadata_;
  // User's defined supported backends.
  const std::unordered_set<std::string> supported_backends_;
  // A mapping table to mlir::Operation objects for TFL subgraph and operator
  // index in a flatbuffer.
  std::vector<std::vector<Operation*>> subgraph_op_inst_map_;

  // Will be populated by ExtractControlEdges to contain the control
  // dependencies contained in the ControlNodeOps. Will then be used to populate
  // metadata in the exported flatbuffer file.
  tflite::ModelControlDependencies model_control_dependencies_;
};
719
720bool Translator::EstimateArithmeticCount(int64_t* count) {
721 int64_t result = 0;
722 bool encounter_undetermined_mac = false;
723 module_->walk([&](mlir::TFL::TflArithmeticCountOpInterface op) {
724 int64_t mac_count = op.GetArithmeticCount(op);
725 if (mac_count < 0) {
726 encounter_undetermined_mac = true;
727 return;
728 }
729 result += mac_count;
730 });
731
732 *count = result;
733 return !encounter_undetermined_mac;
734}
735
736std::string Translator::UniqueName(mlir::Value val) {
737 return std::string(name_mapper_.GetUniqueName(val));
738}
739
// Returns a TFLite buffer populated with the constant value held by `inst`,
// or the shared `empty_buffer_` when `inst` is not one of the recognized
// constant ops. Emits an error and returns llvm::None on conversion failure.
Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
    Operation* inst) {
  ElementsAttr attr;
  if (auto cst = dyn_cast<mlir::arith::ConstantOp>(inst)) {
    // arith::ConstantOp have ElementAttr at this point due to validation of the
    // TFLite module.
    attr = cst.getValue().cast<ElementsAttr>();
  } else if (auto cst = dyn_cast<mlir::TF::ConstOp>(inst)) {
    attr = cst.value();
  } else if (auto cst = dyn_cast<tfl::ConstOp>(inst)) {
    attr = cst.value();
  } else if (auto cst = dyn_cast<tfl::QConstOp>(inst)) {
    attr = cst.value();
  } else if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
    // Sparse constants carry their payload in `compressed_data`.
    attr = cst.compressed_data();
  } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
    attr = cst.compressed_data();
  } else {
    // Not a constant op: share the mandatory empty buffer at index 0.
    return empty_buffer_;
  }

  tensorflow::Tensor tensor;
  auto status = tensorflow::ConvertToTensor(attr, &tensor);
  if (!status.ok()) {
    inst->emitError(
        Twine("failed to convert value attribute to tensor with error: " +
              status.ToString()));
    return llvm::None;
  }

  // TensorFlow and TensorFlow Lite use different string encoding formats.
  // Convert to TensorFlow Lite format if it's a constant string tensor.
  if (tensor.dtype() == tensorflow::DT_STRING) {
    ::tflite::DynamicBuffer dynamic_buffer;
    auto flat = tensor.flat<::tensorflow::tstring>();
    for (int i = 0; i < flat.size(); ++i) {
      const auto& str = flat(i);
      dynamic_buffer.AddString(str.c_str(), str.length());
    }
    char* tensor_buffer;
    // WriteToBuffer heap-allocates `tensor_buffer`; free it after the bytes
    // are copied into the flatbuffer below.
    int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer);
    auto buffer_data =
        builder_.CreateVector(reinterpret_cast<uint8_t*>(tensor_buffer), bytes);
    free(tensor_buffer);
    return tflite::CreateBuffer(builder_, buffer_data);
  }

  // Non-string tensors can be copied into the flatbuffer verbatim.
  absl::string_view tensor_data = tensor.tensor_data();
  auto buffer_data = builder_.CreateVector(
      reinterpret_cast<const uint8_t*>(tensor_data.data()), tensor_data.size());
  return tflite::CreateBuffer(builder_, buffer_data);
}
792
793Optional<std::vector<BufferOffset<tflite::VariantSubType>>>
794Translator::BuildTFVariantType(mlir::Type element_type) {
795 std::vector<BufferOffset<tflite::VariantSubType>> variant_params;
796 auto variant_type = element_type.dyn_cast<mlir::TF::VariantType>();
797 if (!variant_type) {
798 return variant_params;
799 }
800
801 // We only support up to one nested type in tf_type.variant_type.
802 if (variant_type.getSubtypes().size() > 1) {
803 return llvm::None;
804 }
805 if (variant_type.getSubtypes().empty()) {
806 return variant_params;
807 }
808 mlir::TensorType tensor_type = variant_type.getSubtypes().front();
809 tflite::TensorType tflite_element_type =
810 GetTFLiteType(tensor_type.getElementType()).value();
811 std::vector<int32_t> shape;
812 if (tensor_type.hasRank()) {
813 llvm::ArrayRef<int64_t> shape_ref = tensor_type.getShape();
814 shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
815 }
816
817 variant_params.push_back(
818 tflite::CreateVariantSubType(builder_, builder_.CreateVector(shape),
819 tflite_element_type, tensor_type.hasRank()));
820 return variant_params;
821}
822
823Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensorFromType(
824 mlir::Type type, const std::string& name) {
825 auto tensor_type = type.cast<TensorType>();
826
827 llvm::ArrayRef<int64_t> shape_ref;
828 std::vector<int32_t> shape;
829
830 if (tensor_type.hasRank()) {
831 if (tensor_type.hasStaticShape()) {
832 shape_ref = tensor_type.getShape();
833 shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
834 } else {
835 return llvm::None;
836 }
837 }
838
839 auto element_type = tensor_type.getElementType();
840 tflite::TensorType tflite_element_type =
841 GetTFLiteType(tensor_type.getElementType()).value();
842 Optional<std::vector<BufferOffset<tflite::VariantSubType>>> variant_params =
843 BuildTFVariantType(element_type);
844 if (!variant_params.hasValue()) {
845 return llvm::None;
846 }
847 BufferOffset<tflite::QuantizationParameters> q_params = 0;
848 if (auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
849 std::vector<float> scales = {static_cast<float>(qtype.getScale())};
850 std::vector<int64_t> zero_points = {qtype.getZeroPoint()};
851 q_params = tflite::CreateQuantizationParameters(
852 builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>(scales),
853 builder_.CreateVector<int64_t>(zero_points));
854 } else if (auto qtype =
855 element_type
856 .dyn_cast<mlir::quant::CalibratedQuantizedType>()) {
857 std::vector<float> mins = {static_cast<float>(qtype.getMin())};
858 std::vector<float> maxs = {static_cast<float>(qtype.getMax())};
859 q_params = tflite::CreateQuantizationParameters(
860 builder_, builder_.CreateVector<float>(mins),
861 builder_.CreateVector<float>(maxs));
862 }
863 return tflite::CreateTensor(
864 builder_, builder_.CreateVector(shape), tflite_element_type,
865 /*buffer=*/0, builder_.CreateString(name), q_params,
866 /*is_variable=*/false, /*sparsity=*/0, /*shape_signature=*/0,
867 /*has_rank=*/tensor_type.hasRank(),
868 variant_params->empty() ? 0 : builder_.CreateVector(*variant_params));
869}
870
// Builds a flatbuffer `Tensor` for `value`. `name` becomes the tensor name,
// `buffer_idx` is the index of the backing buffer (dropped for variables),
// and `quant_parameters`, when provided, supplies pre-built quantization info
// used for non-quantized element types. Returns llvm::None when a shape
// dimension exceeds the 32-bit range or the variant subtype cannot be built.
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
    Value value, const std::string& name, unsigned buffer_idx,
    const Optional<BufferOffset<tflite::QuantizationParameters>>&
        quant_parameters) {
  auto type = value.getType().cast<TensorType>();

  // TFLite requires tensor shape only for the inputs and constants.
  // However, we output all known shapes for better round-tripping
  auto check_shape =
      [&](llvm::ArrayRef<int64_t> shape_ref) -> mlir::LogicalResult {
    // Flatbuffer shapes are int32; reject anything that would overflow.
    auto is_out_of_range = [](int64_t dim) {
      return dim > std::numeric_limits<int32_t>::max();
    };

    if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
      return mlir::emitError(
          value.getLoc(),
          "result shape dimensions out of 32 bit int type range");

    return mlir::success();
  };

  std::vector<int32_t> shape;
  std::vector<int32_t> shape_signature;
  auto* inst = value.getDefiningOp();
  if (type.hasStaticShape()) {
    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  } else if (inst && IsConst(inst)) {
    // Const op can have a result of dynamic shaped type (e.g. due to constant
    // folding), but we can still derive the shape of a constant tensor for
    // its attribute type.
    auto tensor_attr = inst->getAttr("value").cast<mlir::TypedAttr>();
    llvm::ArrayRef<int64_t> shape_ref =
        tensor_attr.getType().cast<TensorType>().getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  } else if (type.hasRank()) {
    // Ranked but dynamic: export dynamic dims (-1) as 1 in `shape` and keep
    // the true (possibly dynamic) shape in `shape_signature`.
    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape.reserve(shape_ref.size());
    for (auto& dim : shape_ref) {
      shape.push_back(dim == -1 ? 1 : dim);
    }
    shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  }

  // Sparsity parameters come from the defining sparse-const op, if any.
  BufferOffset<tflite::SparsityParameters> s_params = 0;
  if (auto* inst = value.getDefiningOp()) {
    if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
      s_params = BuildSparsityParameters(cst.s_param());
    } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
      s_params = BuildSparsityParameters(cst.s_param());
    }
  }

  Type element_type = type.getElementType();
  tflite::TensorType tflite_element_type =
      GetTFLiteType(type.getElementType()).value();

  Optional<std::vector<BufferOffset<tflite::VariantSubType>>> variant_params =
      BuildTFVariantType(element_type);
  if (!variant_params.hasValue()) {
    return llvm::None;
  }

  BufferOffset<tflite::QuantizationParameters> q_params;
  if (auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
    std::vector<float> scales = {static_cast<float>(qtype.getScale())};
    std::vector<int64_t> zero_points = {qtype.getZeroPoint()};
    q_params = tflite::CreateQuantizationParameters(
        // min and max values are not stored in the quantized type from MLIR, so
        // both are set to 0 in the flatbuffer when they are exported.
        builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>(scales),
        builder_.CreateVector<int64_t>(zero_points));
  } else if (auto qtype =
                 element_type
                     .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
    // Per-axis quantization: one scale/zero-point pair per slice along the
    // quantized dimension.
    std::vector<float> scales(qtype.getScales().begin(),
                              qtype.getScales().end());
    std::vector<int64_t> zero_points(qtype.getZeroPoints().begin(),
                                     qtype.getZeroPoints().end());
    q_params = tflite::CreateQuantizationParameters(
        builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>(scales),
        builder_.CreateVector<int64_t>(zero_points),
        tflite::QuantizationDetails_NONE, /*details=*/0,
        qtype.getQuantizedDimension());
  } else if (quant_parameters.has_value()) {
    q_params = quant_parameters.getValue();
  } else {
    q_params = tflite::CreateQuantizationParameters(builder_);
  }
  // Check if the value's uses includes an op and usage at an operand index
  // marked as a stateful. If so, set the tensor's is_variable as true
  // This is v1 ref variable semantics in the TFLite runtime.
  bool is_variable = false;
  for (auto& use : value.getUses()) {
    is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
    if (is_variable) {
      break;
    }
  }

  bool has_rank = type.hasRank();

  // Variables get buffer index 0 (no constant data); an empty shape_signature
  // is serialized as 0 rather than an empty vector.
  if (shape_signature.empty()) {
    return tflite::CreateTensor(
        builder_, builder_.CreateVector(shape), tflite_element_type,
        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
        /*is_variable=*/is_variable, s_params, /*shape_signature=*/0,
        /*has_rank=*/has_rank,
        variant_params->empty() ? 0 : builder_.CreateVector(*variant_params));
  } else {
    return tflite::CreateTensor(
        builder_, builder_.CreateVector(shape), tflite_element_type,
        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
        /*is_variable=*/is_variable, s_params,
        /*shape_signature=*/builder_.CreateVector(shape_signature),
        /*has_rank=*/has_rank,
        variant_params->empty() ? 0 : builder_.CreateVector(*variant_params));
  }
}
997
998BufferOffset<tflite::Operator> Translator::BuildIfOperator(
999 mlir::TF::IfOp op, const std::vector<int32_t>& operands,
1000 const std::vector<int32_t>& results) {
1001 auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF);
1002 int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str());
1003 int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str());
1004 auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index,
1005 else_subgraph_index)
1006 .Union();
1007 auto inputs = builder_.CreateVector(operands);
1008 auto outputs = builder_.CreateVector(results);
1009 return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
1010 tflite::BuiltinOptions_IfOptions,
1011 builtin_options);
1012}
1013
1014BufferOffset<tflite::Operator> Translator::BuildCallOnceOperator(
1015 mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
1016 const std::vector<int32_t>& results) {
1017 auto opcode_index =
1018 GetOpcodeIndex("call_once", tflite::BuiltinOperator_CALL_ONCE);
1019 int init_subgraph_index =
1020 subgraph_index_map_.at(op.session_init_function().str());
1021 auto builtin_options =
1022 tflite::CreateCallOnceOptions(builder_, init_subgraph_index).Union();
1023 auto inputs = builder_.CreateVector(operands);
1024 auto outputs = builder_.CreateVector(results);
1025 return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
1026 tflite::BuiltinOptions_CallOnceOptions,
1027 builtin_options);
1028}
1029
1030Optional<BufferOffset<tflite::Operator>> Translator::BuildWhileOperator(
1031 mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
1032 const std::vector<int32_t>& results) {
1033 auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE);
1034 auto get_call_index = [&](mlir::Block& b) -> Optional<int> {
1035 if (b.getOperations().size() != 2) return llvm::None;
1036 if (auto call_op = dyn_cast<mlir::func::CallOp>(b.front()))
1037 return subgraph_index_map_.at(call_op.getCallee().str());
1038 return llvm::None;
1039 };
1040 auto body_subgraph_index = get_call_index(op.body().front());
1041 auto cond_subgraph_index = get_call_index(op.cond().front());
1042 if (!body_subgraph_index || !cond_subgraph_index)
1043 return op.emitOpError("only single call cond/body while export supported"),
1044 llvm::None;
1045 auto builtin_options =
1046 tflite::CreateWhileOptions(builder_, *cond_subgraph_index,
1047 *body_subgraph_index)
1048 .Union();
1049 auto inputs = builder_.CreateVector(operands);
1050 auto outputs = builder_.CreateVector(results);
1051 return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
1052 tflite::BuiltinOptions_WhileOptions,
1053 builtin_options);
1054}
1055
1056BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator(
1057 mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
1058 const std::vector<int32_t>& results) {
1059 float tolerance = op.tolerance().convertToFloat();
1060 bool log_if_failed = op.log_if_failed();
1061 auto fbb = std::make_unique<flexbuffers::Builder>();
1062 fbb->Map([&]() {
1063 fbb->Float("tolerance", tolerance);
1064 fbb->Bool("log_if_failed", log_if_failed);
1065 });
1066 fbb->Finish();
1067 auto f = std::unique_ptr<flexbuffers::Builder>(fbb.release());
1068 auto custom_option = f->GetBuffer();
1069 auto opcode_index =
1070 GetOpcodeIndex("NumericVerify", tflite::BuiltinOperator_CUSTOM);
1071 return tflite::CreateOperator(
1072 builder_, opcode_index, builder_.CreateVector(operands),
1073 builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
1074 /*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_option),
1075 tflite::CustomOptionsFormat_FLEXBUFFERS);
1076}
1077
1078BufferOffset<tflite::Operator> Translator::BuildCustomOperator(
1079 Operation* inst, mlir::TFL::CustomOp op,
1080 const std::vector<int32_t>& operands, const std::vector<int32_t>& results) {
1081 const std::string attrs =
1082 op.custom_option().cast<mlir::TFL::ConstBytesAttr>().getValue().str();
1083 std::vector<uint8_t> custom_option_vector(attrs.size());
1084 memcpy(custom_option_vector.data(), attrs.data(), attrs.size());
1085 auto opcode_index =
1086 GetOpcodeIndex(op.custom_code().str(), tflite::BuiltinOperator_CUSTOM);
1087 return tflite::CreateOperator(
1088 builder_, opcode_index, builder_.CreateVector(operands),
1089 builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
1090 /*builtin_options=*/0,
1091 builder_.CreateVector<uint8_t>(custom_option_vector),
1092 tflite::CustomOptionsFormat_FLEXBUFFERS);
1093}
1094
1095Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions(
1096 const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
1097 std::string node_def_str;
1098 if (!node_def.SerializeToString(&node_def_str)) {
1099 return emitError(loc, "failed to serialize tensorflow node_def"),
1100 llvm::None;
1101 }
1102
1103 auto flex_builder = std::make_unique<flexbuffers::Builder>();
1104 flex_builder->Vector([&]() {
1105 flex_builder->String(node_def.op());
1106 flex_builder->String(node_def_str);
1107 });
1108 flex_builder->Finish();
1109 return builder_.CreateVector(flex_builder->GetBuffer());
1110}
1111
1112Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions(
1113 const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
1114 auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc);
1115 return builder_.CreateVector(flex_builder->GetBuffer());
1116}
1117
1118std::unique_ptr<flexbuffers::Builder>
1119Translator::CreateFlexBuilderWithNodeAttrs(
1120 const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) {
1121 auto flex_builder = std::make_unique<flexbuffers::Builder>();
1122 size_t map_start = flex_builder->StartMap();
1123 using Item = std::pair<std::string, ::tensorflow::AttrValue>;
1124 std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end());
1125 std::sort(attrs.begin(), attrs.end(),
1126 [](Item& p1, Item& p2) -> bool { return p1.first < p2.first; });
1127 for (const Item& pair : attrs) {
1128 const char* key = pair.first.c_str();
1129 const ::tensorflow::AttrValue& attr = pair.second;
1130 switch (attr.value_case()) {
1131 case ::tensorflow::AttrValue::kS:
1132 flex_builder->String(key, attr.s());
1133 break;
1134 case ::tensorflow::AttrValue::kType: {
1135 auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type());
1136 if (status_or_tfl_type.ok()) {
1137 flex_builder->Int(key, status_or_tfl_type.value());
1138 } else {
1139 emitWarning(loc, "ignoring unsupported tensorflow type: ")
1140 << std::to_string(attr.type());
1141 }
1142 break;
1143 }
1144 case ::tensorflow::AttrValue::kI:
1145 flex_builder->Int(key, attr.i());
1146 break;
1147 case ::tensorflow::AttrValue::kF:
1148 flex_builder->Float(key, attr.f());
1149 break;
1150 case ::tensorflow::AttrValue::kB:
1151 flex_builder->Bool(key, attr.b());
1152 break;
1153 case tensorflow::AttrValue::kList:
1154 if (attr.list().s_size() > 0) {
1155 auto start = flex_builder->StartVector(key);
1156 for (const std::string& v : attr.list().s()) {
1157 flex_builder->Add(v);
1158 }
1159 flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1160 } else if (attr.list().i_size() > 0) {
1161 auto start = flex_builder->StartVector(key);
1162 for (const int64_t v : attr.list().i()) {
1163 flex_builder->Add(v);
1164 }
1165 flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1166 } else if (attr.list().f_size() > 0) {
1167 auto start = flex_builder->StartVector(key);
1168 for (const float v : attr.list().f()) {
1169 flex_builder->Add(v);
1170 }
1171 flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
1172 } else {
1173 emitWarning(loc,
1174 "ignoring unsupported type in list attribute with key: ")
1175 << key;
1176 }
1177 break;
1178 default:
1179 emitWarning(loc, "ignoring unsupported attribute type with key: ")
1180 << key;
1181 break;
1182 }
1183 }
1184 flex_builder->EndMap(map_start);
1185 flex_builder->Finish();
1186 return flex_builder;
1187}
1188
1189uint32_t Translator::GetOpcodeIndex(const std::string& op_name,
1190 tflite::BuiltinOperator builtin) {
1191 auto it = opcode_index_map_.insert({op_name, 0});
1192
1193 // If the insert succeeded, the opcode has not been created already. Create a
1194 // new operator code and update its index value in the map.
1195 if (it.second) {
1196 it.first->second = opcodes_.size();
1197 auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM
1198 ? builder_.CreateString(op_name)
1199 : BufferOffset<flatbuffers::String>();
1200 // Use version 0 for builtin op. This is a way to serialize version field to
1201 // flatbuffer (since 0 is non default) and it will be corrected later.
1202 int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1;
1203 opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin,
1204 custom_code, op_version));
1205 }
1206 return it.first->second;
1207}
1208
// Serializes `inst` as a flatbuffer Operator. TFL-dialect ops become builtin
// (or TFL custom) operators; TF-dialect ops are exported as flex or custom
// ops carrying the TensorFlow NodeDef in custom_options. `operands`/`results`
// are tensor indices already assigned by the caller; `intermediates` holds
// intermediate-tensor indices (e.g. for LSTM). Returns llvm::None on failure,
// after emitting a diagnostic.
Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
    Operation* inst, std::vector<int32_t> operands,
    const std::vector<int32_t>& results,
    const std::vector<int32_t>& intermediates) {
  const auto* dialect = inst->getDialect();
  if (!dialect) {
    inst->emitOpError("dialect is not registered");
    return llvm::None;
  }

  // If TFLite built in op, create operator as a builtin op.
  if (dialect == tfl_dialect_) {
    // Only if built-in TFLite op emission is enabled, would legalization have
    // converted any TF->TFL.
    if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) {
      return inst->emitOpError(
          "is a TFLite builtin op but builtin emission is not enabled"),
             llvm::None;
    }

    auto builtin_code = GetBuiltinOpCode(inst);
    if (!builtin_code) {
      // No builtin opcode: handle the TFL ops that are exported as custom
      // operators (NumericVerify, tfl.custom) or need special handling
      // (tfl.while).
      if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
        return BuildNumericVerifyOperator(verify_op, operands, results);
      }
      if (auto custom_op = dyn_cast<mlir::TFL::CustomOp>(inst)) {
        return BuildCustomOperator(inst, custom_op, operands, results);
      }
      if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
        if (inst->getNumOperands() != inst->getNumResults()) {
          inst->emitOpError(
              "number of operands and results don't match, only canonical "
              "TFL While supported");
          return llvm::None;
        }
        return BuildWhileOperator(whileOp, operands, results);
      }

      inst->emitOpError("is not a supported TFLite op");
      return llvm::None;
    }

    if (*builtin_code == tflite::BuiltinOperator_CALL_ONCE) {
      if (auto initOp = dyn_cast<mlir::TFL::CallOnceOp>(inst)) {
        return BuildCallOnceOperator(initOp, operands, results);
      }
    }

    std::string op_name = inst->getName().getStringRef().str();
    uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code);

    // If this is TransposeConv we need to do a special case of ignoring the
    // optional tensor, to allow newly created models to run on old runtimes.
    if (*builtin_code == tflite::BuiltinOperator_TRANSPOSE_CONV) {
      if (operands.size() == 4 && operands.at(3) == -1) {
        operands.pop_back();
      }
    }

    auto offset = CreateFlatBufferOperator(inst, opcode_index, operands,
                                           results, intermediates, &builder_);
    if (!offset) {
      inst->emitOpError("is not a supported TFLite op");
    }
    return offset;
  }

  if (dialect == tf_dialect_) {
    if (auto ifOp = dyn_cast<mlir::TF::IfOp>(inst)) {
      return BuildIfOperator(ifOp, operands, results);
    }

    CustomOptionsOffset custom_options;

    // Ops in TF dialect can either be custom ops or flex ops.
    // The reason we go directly from TensorFlow dialect MLIR to tensorflow
    // node instead of going to TF table gen'd ops via generated code is that
    // we do not want to restrict custom and flex op conversion support to
    // only those TF ops that are currently registered in MLIR. The current
    // model is of an open op system.
    //
    // The following algorithm is followed:
    //   if flex is enabled and the op is allowlisted as flex
    //     we emit op as flex.
    //   if custom is enabled
    //    we emit the op as custom.
    auto node_def = GetTensorFlowNodeDef(inst);
    if (!node_def) {
      return llvm::None;
    }

    std::string op_name = node_def->op();
    std::string op_desc = GetOpDescriptionForDebug(inst);

    // Record resource-touching ops for later reporting.
    if (IsTFResourceOp(inst)) {
      resource_ops_[op_name].insert(op_desc);
    }

    // Flex eligibility: not explicitly unsupported, and either allowlisted or
    // user-selected/allow-all with a registered TF kernel.
    const bool is_allowed_flex_op =
        !IsUnsupportedFlexOp(node_def->op()) &&
        (IsAllowlistedFlexOp(node_def->op()) ||
         (((select_user_tf_ops_.count(node_def->op()) != 0) ||
           allow_all_select_tf_ops_) &&
          (tensorflow::OpRegistry::Global()->LookUp(node_def->op()) !=
           nullptr)));

    // Flex op case
    // Eventually, the allowlist will go away and we will rely on some TF op
    // trait (e.g. No side effect) to determine if it is a supported "Flex"
    // op or not.
    if (is_allowed_flex_op && enabled_op_types_.contains(OpType::kSelectTf)) {
      // Construct ops as flex op encoding TensorFlow node definition
      // as custom options.
      // Flex ops are named with the kFlexOpNamePrefix prefix to the actual
      // TF op name.
      op_name = std::string(kFlexOpNamePrefix) + node_def->op();
      if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) {
        custom_options = *options;
      } else {
        return llvm::None;
      }

      // Gather flex ops.
      flex_ops_[op_name].insert(op_desc);
    } else if (enabled_op_types_.contains(OpType::kCustomOp)) {
      // Generic case of custom ops - write using flex buffers since that
      // is the only custom options supported by TFLite today.
      op_name = node_def->op();
      if (auto options =
              CreateCustomOpCustomOptions(*node_def, inst->getLoc())) {
        custom_options = *options;
      } else {
        return llvm::None;
      }

      // Gather custom ops.
      custom_ops_[op_name].insert(op_desc);
    } else {
      // Insert failed op to `flex_ops` or `custom_ops`.
      if (is_allowed_flex_op) {
        failed_flex_ops_[op_name].insert(op_desc);
        tfl::AttachErrorCode(
            inst->emitOpError("is neither a custom op nor a flex op"),
            tflite::metrics::ConverterErrorData::ERROR_NEEDS_FLEX_OPS);
      } else {
        failed_custom_ops_[op_name].insert(op_desc);
        tfl::AttachErrorCode(
            inst->emitOpError("is neither a custom op nor a flex op"),
            tflite::metrics::ConverterErrorData::ERROR_NEEDS_CUSTOM_OPS);
      }
      return llvm::None;
    }

    uint32_t opcode_index =
        GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM);
    auto inputs = builder_.CreateVector(operands);
    auto outputs = builder_.CreateVector(results);

    return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
                                  tflite::BuiltinOptions_NONE,
                                  /*builtin_options=*/0,
                                  /*custom_options=*/custom_options,
                                  tflite::CustomOptionsFormat_FLEXBUFFERS,
                                  /*mutating_variable_inputs=*/0);
  }

  return inst->emitOpError(
             "is not any of a builtin TFLite op, a flex TensorFlow op or a "
             "custom TensorFlow op"),
         llvm::None;
}
1380
1381void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) {
1382 auto dict_attr = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1383 if (!dict_attr) return;
1384
1385 llvm::SmallVector<llvm::StringRef, 2> input_names;
1386 llvm::SmallVector<llvm::StringRef, 2> output_names;
1387 if (auto str = dict_attr.get("inputs").dyn_cast_or_null<mlir::StringAttr>()) {
1388 str.getValue().split(input_names, ',', /*MaxSplit=*/-1,
1389 /*KeepEmpty=*/false);
1390 if (input_names.size() != fn.getNumArguments()) {
1391 fn.emitWarning() << "invalid entry function specification";
1392 return;
1393 }
1394 for (const auto& it : llvm::enumerate(fn.getArguments())) {
1395 name_mapper_.InitOpName(it.value(), input_names[it.index()].trim());
1396 }
1397 *has_input_attr = true;
1398 }
1399
1400 if (auto str =
1401 dict_attr.get("outputs").dyn_cast_or_null<mlir::StringAttr>()) {
1402 str.getValue().split(output_names, ',', /*MaxSplit=*/-1,
1403 /*KeepEmpty=*/false);
1404 auto term = fn.back().getTerminator();
1405 if (output_names.size() != term->getNumOperands()) {
1406 fn.emitWarning() << "output names (" << output_names.size()
1407 << ") != terminator operands (" << term->getNumOperands()
1408 << ")";
1409 return;
1410 }
1411 for (const auto& it : llvm::enumerate(term->getOperands())) {
1412 name_mapper_.InitOpName(it.value(), output_names[it.index()].trim());
1413 }
1414 }
1415}
1416
1417bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) {
1418 std::vector<int> operand_indices;
1419 if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false;
1420 return absl::c_find(operand_indices, operand_index) != operand_indices.end();
1421}
1422
1423BufferOffset<tflite::QuantizationParameters>
1424Translator::GetQuantizationForQuantStatsOpOutput(
1425 mlir::quantfork::StatisticsOp stats_op) {
1426 auto layer_stats = stats_op.getLayerStats().cast<mlir::DenseFPElementsAttr>();
1427 Optional<mlir::ElementsAttr> axis_stats = stats_op.getAxisStats();
1428 Optional<uint64_t> axis = stats_op.getAxis();
1429 std::vector<float> mins, maxs;
1430 mlir::DenseFPElementsAttr min_max_attr =
1431 axis_stats.has_value()
1432 ? axis_stats.getValue().cast<mlir::DenseFPElementsAttr>()
1433 : layer_stats;
1434
1435 for (const auto& index_and_value :
1436 llvm::enumerate(min_max_attr.getValues<llvm::APFloat>())) {
1437 const llvm::APFloat value = index_and_value.value();
1438 if (index_and_value.index() % 2 == 0) {
1439 mins.push_back(value.convertToFloat());
1440 } else {
1441 maxs.push_back(value.convertToFloat());
1442 }
1443 }
1444
1445 return tflite::CreateQuantizationParameters(
1446 builder_, builder_.CreateVector<float>(mins),
1447 builder_.CreateVector<float>(maxs), /*scale=*/0, /*zero_point=*/0,
1448 tflite::QuantizationDetails_NONE, /*details=*/0,
1449 /*quantized_dimension=*/axis.has_value() ? axis.getValue() : 0);
1450}
1451
// Builds a flatbuffer SubGraph from `region` (subgraph number `index`):
// creates tensors and buffers for block arguments and op results, serializes
// each non-const operation as an Operator, wires subgraph inputs/outputs,
// and records control-dependency edges extracted from the block. Returns
// llvm::None when any tensor or operator fails to build.
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
    const std::string& name, Region* region, const int index) {
  const auto control_edges = ExtractControlEdges(&region->front());
  bool has_input_attr = false;
  if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
    InitializeNamesFromAttribute(fn, &has_input_attr);
  }
  std::vector<BufferOffset<tflite::Tensor>> tensors;
  llvm::DenseMap<Value, int> tensor_index_map;

  // Builds tensor and buffer for argument or operation result. Returns false
  // on failure.
  auto build_tensor_and_buffer = [&](Value value, const int subgraph_index,
                                     const std::string& tensor_name) {
    // NoneType represents optional and may be skipped here.
    if (value.getType().isa<NoneType>()) {
      return true;
    }

    tensor_index_map.insert({value, tensors.size()});
    tensor_index_map_[subgraph_index][tensor_name] = tensors.size();
    // A value consumed solely by a quant.stats op inherits that op's
    // min/max statistics as its quantization parameters.
    Optional<BufferOffset<tflite::QuantizationParameters>> quant_parameters;
    if (value.hasOneUse()) {
      auto stats_op =
          llvm::dyn_cast<mlir::quantfork::StatisticsOp>(*value.user_begin());
      if (stats_op) {
        quant_parameters = GetQuantizationForQuantStatsOpOutput(stats_op);
      }
    }
    // Note: buffers_.size() is passed as the buffer index BEFORE the buffer
    // is pushed below, so tensor and buffer stay in lockstep.
    auto tensor_or =
        BuildTensor(value, tensor_name, buffers_.size(), quant_parameters);
    if (!tensor_or) return false;
    tensors.push_back(*tensor_or);

    // TODO(ashwinm): Check if for stateful tensors, if it is also needed to
    // make the Buffer empty apart from setting the buffer_idx=0 in the
    // Tensor. This does not seem to affect runtime behavior for RNN/LSTM,
    // but would be good for reducing memory footprint.
    if (auto* inst = value.getDefiningOp()) {
      auto buffer_or = BuildBuffer(inst);
      if (!buffer_or) return false;
      buffers_.push_back(*buffer_or);
    } else {
      buffers_.push_back(empty_buffer_);
    }
    return true;
  };

  std::vector<BufferOffset<tflite::Operator>> operators;

  // Maps positions of operations in bb to positions in operators
  llvm::DenseMap<int, int> operation_index_to_operator_index;
  std::vector<Operation*> operators_in_mlir;
  auto& bb = region->front();

  // Main function's arguments are first passed to `input` op so they don't
  // have associated tensor and buffer. Build FlatBuffer tensor and buffer for
  // other functions.
  for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
    mlir::BlockArgument arg = bb.getArgument(i);
    std::string tensor_name;
    if (has_input_attr)
      tensor_name = std::string(name_mapper_.GetUniqueName(arg));
    if (tensor_name.empty()) tensor_name = absl::StrCat("arg", i);
    if (!build_tensor_and_buffer(arg, index, tensor_name)) return llvm::None;
  }

  bool failed_once = false;
  for (auto& item : llvm::enumerate(bb)) {
    Operation& inst = item.value();
    const int operation_index = item.index();
    if (inst.hasTrait<mlir::OpTrait::IsTerminator>()) break;
    // For "quant.stats" op, it's used to store the quantization parameters info
    // and its output should be then replaced by its input value.
    if (auto quant_stats_op =
            llvm::dyn_cast<mlir::quantfork::StatisticsOp>(inst)) {
      continue;
    }
    std::vector<int32_t> intermediates;
    // Build intermediate tensors for tfl.lstm and insert these tensors into
    // flatbuffer.
    if (llvm::isa<mlir::TFL::LSTMOp, mlir::TFL::UnidirectionalSequenceLSTMOp>(
            inst)) {
      std::vector<std::string> intermediate_names = {
          "input_to_input_intermediate", "input_to_forget_intermediate",
          "input_to_cell_intermediate", "input_to_output_intermediate",
          "effective_hidden_scale_intermediate"};
      for (const std::string& intermediate : intermediate_names) {
        auto intermediate_attr = inst.getAttr(intermediate);
        if (auto attr = intermediate_attr.dyn_cast_or_null<mlir::TypeAttr>()) {
          Type qtype = attr.getValue();
          auto tensor_or = BuildTensorFromType(
              qtype, name_mapper_.GetUniqueName(intermediate).str());
          if (!tensor_or.has_value()) {
            // Intermediate tensors that cannot be built are skipped rather
            // than failing the whole subgraph.
            continue;
          } else {
            intermediates.push_back(tensors.size());
            tensors.push_back(tensor_or.getValue());
          }
        }
      }
    }

    for (auto val : inst.getResults()) {
      std::string tensor_name = UniqueName(val);
      // For "tfl.numeric_verify" op, the name is used to find out the original
      // activation tensor rather than its own unique name in the visualization
      // or debugging tools.
      auto builtin_code = GetBuiltinOpCode(&inst);
      if (!builtin_code && dyn_cast<mlir::TFL::NumericVerifyOp>(&inst)) {
        // The first operand is the quantized activation, the target of this
        // NumericVerify op.
        auto quantized_op_val = inst.getOperands().front();
        tensor_name = "NumericVerify/" + UniqueName(quantized_op_val) + ":" +
                      std::to_string(tensor_index_map[quantized_op_val]);
      }
      if (!build_tensor_and_buffer(val, index, tensor_name)) return llvm::None;
    }

    // Skip constant ops as they don't represent a TFLite operator.
    if (IsConst(&inst)) continue;

    // Fetch operand and result tensor indices.
    std::vector<int32_t> results;
    results.reserve(inst.getNumResults());
    for (auto result : inst.getResults()) {
      results.push_back(tensor_index_map.lookup(result));
    }
    Operation* real_inst = &inst;
    std::vector<int32_t> operands;
    operands.reserve(real_inst->getNumOperands());
    for (auto operand : real_inst->getOperands()) {
      // NoneType operands are optional tensors; quant.stats producers are
      // transparent, so use their input's tensor index instead.
      if (operand.getType().isa<NoneType>())
        operands.push_back(kTfLiteOptionalTensor);
      else if (auto stats_op =
                   llvm::dyn_cast_or_null<mlir::quantfork::StatisticsOp>(
                       operand.getDefiningOp()))
        operands.push_back(tensor_index_map.lookup(stats_op.getArg()));
      else
        operands.push_back(tensor_index_map.lookup(operand));
    }

    // CustomTfOp is just a wrapper around a TF op, we export the custom Op
    // not the wrapper, so we fetch the op from the region.
    if (auto custom_op = dyn_cast<mlir::TFL::CustomTfOp>(inst)) {
      // If we have custom op with a region, then use the first op in the
      // region, if it exists, otherwise just use params for custom op.
      if (!custom_op.body().empty()) {
        real_inst = &custom_op.body().front().front();
      } else {
        module_.emitError(
            "Invalid CustomTfOp: Custom TF Op have empty region.");
      }
    }
    if (auto tfl_operator =
            BuildOperator(real_inst, operands, results, intermediates)) {
      operation_index_to_operator_index.try_emplace(operation_index,
                                                    operators.size());
      operators.push_back(*tfl_operator);
      operators_in_mlir.push_back(real_inst);
    } else {
      // Keep going after a failed operator so all errors are reported, and
      // fail the subgraph at the end.
      failed_once = true;
    }
  }
  if (index + 1 > subgraph_op_inst_map_.size()) {
    subgraph_op_inst_map_.resize(index + 1);
  }
  subgraph_op_inst_map_[index] = operators_in_mlir;
  if (failed_once) return llvm::None;

  // Get input and output tensor indices for the subgraph.
  std::vector<int32_t> inputs, outputs;
  for (auto arg : bb.getArguments()) {
    inputs.push_back(tensor_index_map[arg]);
  }
  for (auto result : bb.getTerminator()->getOperands()) {
    outputs.push_back(tensor_index_map[result]);
  }
  // Translate control edges from operation indices to operator indices.
  for (const auto& [from, to] : control_edges) {
    for (int what : {from, to}) {
      if (operation_index_to_operator_index.count(what) == 0) {
        module_.emitError(
            "dangling control edge -- at least one vertex Operation isn't a "
            "flatbuffer Operator.");
      }
    }
    model_control_dependencies_[index].emplace_back(
        operation_index_to_operator_index[from],
        operation_index_to_operator_index[to]);
  }
  return tflite::CreateSubGraph(
      builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
      builder_.CreateVector(outputs), builder_.CreateVector(operators),
      /*name=*/builder_.CreateString(name));
}
1647
1648BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name,
1649 StringRef content) {
1650 auto buffer_index = buffers_.size();
1651 auto buffer_data = builder_.CreateVector(
1652 reinterpret_cast<const uint8_t*>(content.data()), content.size());
1653 buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data));
1654 return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index);
1655}
1656
1657Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
1658Translator::CreateMetadataVector() {
1659 auto dict_attr = module_->getAttrOfType<mlir::DictionaryAttr>("tfl.metadata");
1660 std::vector<BufferOffset<tflite::Metadata>> metadata;
1661 if (dict_attr) {
1662 for (const auto& named_attr : dict_attr) {
1663 StringRef name = named_attr.getName();
1664 mlir::Attribute attr = named_attr.getValue();
1665 if (auto content = attr.dyn_cast<StringAttr>()) {
1666 metadata.push_back(BuildMetadata(name, content.getValue()));
1667 } else {
1668 module_.emitError(
1669 "all values in tfl.metadata's dictionary key-value pairs should be "
1670 "string attributes");
1671 return llvm::None;
1672 }
1673 }
1674 }
1675 // Runtime version string is generated after we update the op
1676 // versions. Here we put a 16-byte dummy string as a placeholder. We choose
1677 // 16-byte because it's the alignment of buffers in flatbuffer, so it won't
1678 // cause any waste of space if the actual string is shorter than 16 bytes.
1679 constexpr std::size_t kByteStringSize = 16;
1680 metadata.push_back(
1681 BuildMetadata("min_runtime_version", std::string(kByteStringSize, '\0')));
1682 for (const auto& kv : metadata_) {
1683 const std::string& val = kv.second;
1684 // Only take the first kByteStringSize values.
1685 const int count = std::min(kByteStringSize, val.length());
1686 std::string value = std::string(kByteStringSize, '\0')
1687 .assign(val.begin(), val.begin() + count);
1688 metadata.push_back(BuildMetadata(kv.first, value));
1689 }
1690
1691 // Populate the model control dependencies metadata entry.
1692 if (std::any_of(
1693 model_control_dependencies_.begin(),
1694 model_control_dependencies_.end(),
1695 [](const tflite::ControlEdges& edges) { return !edges.empty(); })) {
1696 metadata.push_back(
1697 BuildMetadata(tflite::kModelControlDependenciesMetadataKey,
1698 tflite::SerializeModelControlDependencies(
1699 model_control_dependencies_)));
1700 }
1701 return builder_.CreateVector(metadata);
1702}
1703
1704// Helper method that returns list of all strings in a StringAttr identified
1705// by 'attr_key' and values are separated by a comma.
1706llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator(
1707 mlir::DictionaryAttr attr, const std::string& attr_key) {
1708 llvm::SmallVector<llvm::StringRef, 2> result;
1709 if (auto str = attr.get(attr_key).dyn_cast_or_null<mlir::StringAttr>()) {
1710 str.getValue().split(result, ',', /*MaxSplit=*/-1,
1711 /*KeepEmpty=*/false);
1712 }
1713 return result;
1714}
1715
1716// Helper method that return list of string for all the StringAttr in the
1717// Attribute identified by 'attr_name'.
1718std::vector<std::string> GetStringsFromDictionaryAttr(
1719 const llvm::SmallVector<mlir::DictionaryAttr, 4>& dict_attrs,
1720 const std::string& attr_name) {
1721 std::vector<std::string> result;
1722 for (const auto& arg_attr : dict_attrs) {
1723 if (!arg_attr) continue;
1724
1725 auto attrs = arg_attr.getValue();
1726 for (const auto attr : attrs) {
1727 if (attr.getName().str() == attr_name) {
1728 auto array_attr = attr.getValue().dyn_cast_or_null<mlir::ArrayAttr>();
1729 if (!array_attr || array_attr.empty()) continue;
1730 auto string_attr = array_attr[0].dyn_cast_or_null<mlir::StringAttr>();
1731 if (!string_attr) continue;
1732 result.push_back(string_attr.getValue().str());
1733 }
1734 }
1735 }
1736 return result;
1737}
1738
1739std::vector<SignatureDefData> BuildSignaturedef(
1740 FuncOp main_op, const std::string& saved_model_tag,
1741 const uint32_t subgraph_index, tensorflow::OpOrArgNameMapper& name_mapper) {
1742 static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path";
1743 static const char kEntryFunctionAttributes[] = "tf.entry_function";
1744
1745 // Fetch inputs and outputs from the signature.
1746 llvm::SmallVector<mlir::DictionaryAttr, 4> arg_attrs, res_attrs;
1747 main_op.getAllArgAttrs(arg_attrs);
1748 main_op.getAllResultAttrs(res_attrs);
1749 std::vector<std::string> sig_def_inputs =
1750 GetStringsFromDictionaryAttr(arg_attrs, kSignatureDefIndexPath);
1751 std::vector<std::string> sig_def_outputs =
1752 GetStringsFromDictionaryAttr(res_attrs, kSignatureDefIndexPath);
1753
1754 // If no defined saved model signature, then return empty list.
1755 // This can happen when we are converting model not from SavedModel.
1756 if (sig_def_inputs.empty() && sig_def_outputs.empty()) return {};
1757
1758 // Fetch function inputs and outputs tensor names.
1759 auto dict_attr =
1760 main_op->getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes);
1761 if (!dict_attr) return {};
1762
1763 // Get Input and output tensor names from attribute.
1764 llvm::SmallVector<llvm::StringRef, 2> input_names =
1765 GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs");
1766 llvm::SmallVector<llvm::StringRef, 2> output_names =
1767 GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs");
1768
1769 // Verify input size match the number of arguments.
1770 if (input_names.size() != main_op.getNumArguments()) {
1771 main_op.emitWarning() << "invalid entry function specification";
1772 return {};
1773 }
1774 // Verify output size match the number of arguments.
1775 auto term = main_op.back().getTerminator();
1776 if (output_names.size() != term->getNumOperands()) {
1777 main_op.emitWarning() << "output names (" << output_names.size()
1778 << ") != terminator operands ("
1779 << term->getNumOperands() << ")";
1780 return {};
1781 }
1782 // Verify number of tensors for inputs and outputs matches size
1783 // of the list in the signature def.
1784 if (input_names.size() != sig_def_inputs.size() ||
1785 output_names.size() != sig_def_outputs.size()) {
1786 main_op.emitWarning(
1787 "Mismatch between signature def inputs/outputs and main function "
1788 "arguments.");
1789 return {};
1790 }
1791 // Exported method name.
1792 auto exported_name =
1793 main_op->getAttrOfType<mlir::ArrayAttr>("tf_saved_model.exported_names");
1794 if (exported_name.empty()) {
1795 main_op.emitError("Empty exported names for main Function");
1796 return {};
1797 }
1798 // Fill the SignatureDefData container.
1799 // We create vector of size 1 as TFLite now supports only 1 signatureDef.
1800 std::vector<SignatureDefData> result(1);
1801 for (int i = 0; i < input_names.size(); ++i) {
1802 result[0].inputs[sig_def_inputs[i]] = input_names[i].str();
1803 }
1804 for (int i = 0; i < output_names.size(); ++i) {
1805 // Fetch the name from the actual operand and not rely on names from
1806 // outputs as deduping can make them invalid after conversion.
1807 auto& operand = term->getOpOperand(i);
1808 auto unique_name = std::string(name_mapper.GetUniqueName(operand.get()));
1809 result[0].outputs[sig_def_outputs[i]] = unique_name;
1810 }
1811 if (auto name_attr = exported_name[0].dyn_cast_or_null<StringAttr>())
1812 result[0].signature_key = name_attr.getValue().str();
1813 result[0].subgraph_index = subgraph_index;
1814 return result;
1815}
1816
1817std::vector<BufferOffset<tflite::TensorMap>> Translator::GetList(
1818 const int subgraph_index, const std::map<std::string, std::string>& items) {
1819 std::vector<BufferOffset<tflite::TensorMap>> result;
1820 for (const auto& item : items) {
1821 auto name_buf = builder_.CreateString(item.first);
1822 tflite::TensorMapBuilder tensor_map_builder(builder_);
1823 tensor_map_builder.add_name(name_buf);
1824 tensor_map_builder.add_tensor_index(
1825 tensor_index_map_[subgraph_index][item.second]);
1826 result.push_back(tensor_map_builder.Finish());
1827 }
1828 return result;
1829}
1830
1831Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
1832Translator::CreateSignatureDefs(
1833 const std::vector<SignatureDefData>& signature_defs) {
1834 std::vector<BufferOffset<tflite::SignatureDef>> signature_defs_buffer;
1835 // When we export each function in the module op, intentionally, we export the
1836 // entry functions at the beginning of the subgraph list and the
1837 // subgraph_index is the index in entry functions and at the same, is the
1838 // index in the subgraph list.
1839 int subgraph_index = 0;
1840 for (const auto& signature_def_data : signature_defs) {
1841 auto inputs = GetList(subgraph_index, signature_def_data.inputs);
1842 auto outputs = GetList(subgraph_index, signature_def_data.outputs);
1843 auto inputs_buf = builder_.CreateVector(inputs);
1844 auto outputs_buf = builder_.CreateVector(outputs);
1845 auto signature_key_buf =
1846 builder_.CreateString(signature_def_data.signature_key);
1847 tflite::SignatureDefBuilder sig_def_builder(builder_);
1848 sig_def_builder.add_inputs(inputs_buf);
1849 sig_def_builder.add_outputs(outputs_buf);
1850 sig_def_builder.add_signature_key(signature_key_buf);
1851 sig_def_builder.add_subgraph_index(signature_def_data.subgraph_index);
1852 signature_defs_buffer.push_back(sig_def_builder.Finish());
1853 ++subgraph_index;
1854 }
1855
1856 return builder_.CreateVector(signature_defs_buffer);
1857}
1858
1859bool UpdateEntryFunction(ModuleOp module) {
1860 if (module.lookupSymbol<FuncOp>("main") != nullptr) {
1861 // We already have an entry function.
1862 return true;
1863 }
1864
1865 int entry_func_count = 0;
1866 FuncOp entry_func = nullptr;
1867 for (auto fn : module.getOps<FuncOp>()) {
1868 auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
1869 if (!attrs || attrs.empty()) continue;
1870 ++entry_func_count;
1871 entry_func = fn;
1872 }
1873
1874 // We should have at least one entry function.
1875 if (entry_func_count == 0) return false;
1876
1877 if (entry_func_count == 1) {
1878 // Update the entry func to main when the entry func is only & one.
1879 entry_func.setName(StringAttr::get(module.getContext(), "main"));
1880 }
1881 return true;
1882}
1883
1884Optional<std::string> Translator::Translate(
1885 ModuleOp module, const toco::TocoFlags& toco_flags,
1886 const std::unordered_set<std::string>& tags,
1887 OpOrArgNameMapper* op_or_arg_name_mapper,
1888 const std::map<std::string, std::string>& metadata) {
1889 OpOrArgLocNameMapper default_op_or_arg_name_mapper;
1890 if (!op_or_arg_name_mapper)
1891 op_or_arg_name_mapper = &default_op_or_arg_name_mapper;
1892 if (!UpdateEntryFunction(module)) return llvm::None;
1893 if (!IsValidTFLiteMlirModule(module)) return llvm::None;
1894 Translator translator(module, toco_flags, tags, op_or_arg_name_mapper,
1895 metadata);
1896 return translator.TranslateInternal();
1897}
1898
1899bool Translator::CheckGpuDelegateCompatibility(uint8_t* model_buffer_pointer) {
1900 bool gpu_compatibile = true;
1901 auto model = tflite::GetModel(model_buffer_pointer);
1902 auto subgraphs = model->subgraphs();
1903
1904 for (int i = 0; i < subgraphs->Length(); ++i) {
1905 const tflite::SubGraph* subgraph = subgraphs->Get(i);
1906 for (int j = 0; j < subgraph->operators()->Length(); ++j) {
1907 const tflite::Operator* op = subgraph->operators()->Get(j);
1908 const tflite::OperatorCode* op_code =
1909 model->operator_codes()->Get(op->opcode_index());
1910 auto status =
1911 tflite::CheckGpuDelegateCompatibility(op_code, op, subgraph, model);
1912 if (!status.ok()) {
1913 gpu_compatibile = false;
1914 auto inst = subgraph_op_inst_map_[i][j];
1915 tfl::AttachErrorCode(
1916 inst->emitOpError()
1917 << "is not GPU compatible: " << std::string(status.message()),
1918 tflite::metrics::ConverterErrorData::ERROR_GPU_NOT_COMPATIBLE);
1919 }
1920 }
1921 }
1922 return gpu_compatibile;
1923}
1924
// Drives the whole export: classifies functions into entry / non-entry,
// builds a subgraph per function (entry functions first), emits diagnostics
// for resource/flex/custom ops, assembles metadata and signature defs, and
// finally serializes the flatbuffer model. Returns llvm::None on any failure.
Optional<std::string> Translator::TranslateInternal() {
  // A list of named regions in the module with main function being the first in
  // the list. The main function is required as the first subgraph in the model
  // is entry point for the model.
  std::vector<std::pair<std::string, Region*>> named_regions;
  named_regions.reserve(std::distance(module_.begin(), module_.end()));

  int subgraph_idx = 0;

  // Entry functions for signature defs.
  std::vector<FuncOp> entry_functions;
  std::vector<FuncOp> non_entry_functions;
  FuncOp main_fn = module_.lookupSymbol<FuncOp>("main");
  if (main_fn != nullptr) {
    // Treat the main function as a signature def when the given main function
    // contains on the tf.entry_function attribute.
    auto attrs =
        main_fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
    if (attrs && !attrs.empty()) {
      entry_functions.push_back(main_fn);
    } else {
      non_entry_functions.push_back(main_fn);
    }
  }

  // Walk over the module collection ops with functions and while ops.
  module_.walk([&](FuncOp fn) {
    // "main" was already classified above; skip it here.
    if (main_fn == fn) return WalkResult::advance();
    auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
    if (attrs && !attrs.empty()) {
      entry_functions.push_back(fn);
    } else {
      non_entry_functions.push_back(fn);
    }
    return WalkResult::advance();
  });

  // Assign the subgraph index. Among the given functions, it will put entry
  // functions at the beginning of the list of the subgrahs.
  for (auto fn : entry_functions) {
    subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
    named_regions.emplace_back(fn.getName().str(), &fn.getBody());
  }
  for (auto fn : non_entry_functions) {
    subgraph_index_map_[fn.getName().str()] = subgraph_idx++;
    named_regions.emplace_back(fn.getName().str(), &fn.getBody());
  }

  // Build subgraph for each of the named regions.
  std::vector<BufferOffset<tflite::SubGraph>> subgraphs;
  subgraphs.reserve(named_regions.size());
  model_control_dependencies_.assign(named_regions.size(), {});
  int first_failed_func = -1;

  // When we export each function in the module op, intentionally, we export the
  // entry functions at the beginning of the subgraph list and the
  // subgraph_index is the index in entry functions and at the same, is the
  // index in the subgraph list.
  int subgraph_index = 0;
  for (const auto& it : llvm::enumerate(named_regions)) {
    auto subgraph_or =
        BuildSubGraph(it.value().first, it.value().second, subgraph_index);
    if (!subgraph_or) {
      if (first_failed_func == -1)
        // Record the index of the first region that cannot be converted.
        // Keep looping through all subgraphs in the module to make sure that
        // we collect the list of missing ops from the entire module.
        first_failed_func = it.index();
    } else {
      subgraphs.push_back(*subgraph_or);
      // Note: only successfully-converted subgraphs advance this counter.
      ++subgraph_index;
    }
  }

  if (!resource_ops_.empty()) {
    std::string resource_ops_summary =
        GetOpsSummary(resource_ops_, /*summary_title=*/"Resource");
    LOG(WARNING) << "Graph contains the following resource op(s), that use(s) "
                    "resource type. Currently, the "
                    "resource type is not natively supported in TFLite. Please "
                    "consider not using the resource type if there are issues "
                    "with either TFLite converter or TFLite runtime:\n"
                 << resource_ops_summary;
  }

  if (!flex_ops_.empty()) {
    std::string flex_ops_summary =
        GetOpsSummary(flex_ops_, /*summary_title=*/"Flex");
    LOG(WARNING) << "TFLite interpreter needs to link Flex delegate in order "
                    "to run the model since it contains the following Select TF"
                    "op(s):\n"
                 << flex_ops_summary
                 << "\nSee instructions: "
                    "https://www.tensorflow.org/lite/guide/ops_select";
  }

  if (!custom_ops_.empty()) {
    std::string custom_ops_summary =
        GetOpsSummary(custom_ops_, /*summary_title=*/"Custom");
    LOG(WARNING) << "The following operation(s) need TFLite custom op "
                    "implementation(s):\n"
                 << custom_ops_summary
                 << "\nSee instructions: "
                    "https://www.tensorflow.org/lite/guide/ops_custom";
  }

  // Abort with a detailed error once all regions have been scanned, so the
  // summaries above cover the entire module.
  if (first_failed_func != -1) {
    std::string failed_flex_ops_summary =
        GetOpsSummary(failed_flex_ops_, /*summary_title=*/"TF Select");
    std::string failed_custom_ops_summary =
        GetOpsSummary(failed_custom_ops_, /*summary_title=*/"Custom");
    std::string err;
    if (!failed_flex_ops_.empty())
      err +=
          "\nSome ops are not supported by the native TFLite runtime, you can "
          "enable TF kernels fallback using TF Select. See instructions: "
          "https://www.tensorflow.org/lite/guide/ops_select \n" +
          failed_flex_ops_summary + "\n";
    if (!failed_custom_ops_.empty())
      err +=
          "\nSome ops in the model are custom ops, "
          "See instructions to implement "
          "custom ops: https://www.tensorflow.org/lite/guide/ops_custom \n" +
          failed_custom_ops_summary + "\n";

    auto& failed_region = named_regions[first_failed_func];
    // Note: comma operator -- emit the error, then evaluate to llvm::None.
    return failed_region.second->getParentOp()->emitError()
               << "failed while converting: '" << failed_region.first
               << "': " << err,
           llvm::None;
  }

  // Log MAC count.
  int64_t ops_count;
  if (EstimateArithmeticCount(&ops_count)) {
    const int64_t million = 1e6;
    const int64_t billion = 1e9;
    std::string flops_str;
    std::string mac_str;
    if (ops_count < 10000) {
      flops_str = absl::StrFormat("%ld ", ops_count);
      mac_str = absl::StrFormat("%ld ", ops_count / 2);
    } else if (ops_count < billion) {
      flops_str =
          absl::StrFormat("%.3f M ", static_cast<double>(ops_count) / million);
      mac_str = absl::StrFormat("%.3f M ",
                                static_cast<double>(ops_count / 2) / million);
    } else {
      flops_str =
          absl::StrFormat("%.3f G ", static_cast<double>(ops_count) / billion);
      mac_str = absl::StrFormat("%.3f G ",
                                static_cast<double>(ops_count / 2) / billion);
    }
    LOG(INFO) << "Estimated count of arithmetic ops: " << flops_str
              << " ops, equivalently " << mac_str << " MACs";
  }

  std::string model_description;
  if (auto attr = module_->getAttrOfType<StringAttr>("tfl.description")) {
    model_description = attr.getValue().str();
  } else {
    model_description = "MLIR Converted.";
  }

  // Build the model and finish the model building process.
  auto description = builder_.CreateString(model_description.data());
  VectorBufferOffset<int32_t> metadata_buffer = 0;  // Deprecated
  auto metadata = CreateMetadataVector();
  if (!metadata) return llvm::None;

  std::vector<SignatureDefData> signature_defs_vec;
  subgraph_index = 0;
  // Build SignatureDefs for the tf.entry_function based func ops.
  for (auto fn : entry_functions) {
    auto signature_defs = BuildSignaturedef(
        fn, saved_model_tags_.empty() ? "" : *saved_model_tags_.begin(),
        subgraph_index, name_mapper_);
    for (const auto& signature_def : signature_defs) {
      signature_defs_vec.push_back(signature_def);
    }
    // When we export each function in the module op, intentionally, we export
    // the entry functions at the beginning of the subgraph list and the
    // subgraph_index is the index in entry functions and at the same, is the
    // index in the subgraph list.
    ++subgraph_index;
  }
  auto signature_defs = CreateSignatureDefs(signature_defs_vec);

  auto model = tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION,
                                   builder_.CreateVector(opcodes_),
                                   builder_.CreateVector(subgraphs),
                                   description, builder_.CreateVector(buffers_),
                                   metadata_buffer, *metadata, *signature_defs);
  tflite::FinishModelBuffer(builder_, model);
  // There is a limit of 2GB for a flatbuffer.
  if (builder_.GetSize() > 2147483648) {
    LOG(ERROR) << "Model size is bigger than 2gb";
    return llvm::None;
  }
  // Rewrite op versions and the min_runtime_version placeholder in place now
  // that the buffer is finished.
  tflite::UpdateOpVersion(builder_.GetBufferPointer());
  tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer());
  if (supported_backends_.find("GPU") != supported_backends_.end()) {
    if (!CheckGpuDelegateCompatibility(builder_.GetBufferPointer())) {
      return llvm::None;
    }
  }

  // Return serialized string for the built FlatBuffer.
  return std::string(reinterpret_cast<const char*>(builder_.GetBufferPointer()),
                     builder_.GetSize());
}
2136
2137BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
2138 const mlir::TFL::SparsityParameterAttr& s_attr) {
2139 const int dim_size = s_attr.getDimMetadata().size();
2140 std::vector<flatbuffers::Offset<tflite::DimensionMetadata>> fb_dim_metadata(
2141 dim_size);
2142 for (int i = 0; i < dim_size; i++) {
2143 const auto dim_metadata =
2144 s_attr.getDimMetadata()[i].dyn_cast<mlir::TFL::DimensionMetadataAttr>();
2145 if (dim_metadata.getFormat().getValue() ==
2146 mlir::TFL::DimensionType::DENSE) {
2147 fb_dim_metadata[i] = tflite::CreateDimensionMetadata(
2148 builder_, tflite::DimensionType_DENSE, dim_metadata.getDenseSize());
2149
2150 } else {
2151 auto segments = dim_metadata.getSegments();
2152 std::vector<int> vector_segments(segments.size(), 0);
2153 for (int j = 0, end = segments.size(); j < end; j++) {
2154 vector_segments[j] = segments[j];
2155 }
2156 tflite::SparseIndexVector segments_type;
2157 BufferOffset<void> array_segments;
2158 // The segment array is sorted.
2159 // TODO(b/147449640): Clean this up with util functions.
2160 int max_of_segments = vector_segments[segments.size() - 1];
2161 if (max_of_segments <= UINT8_MAX) {
2162 segments_type = tflite::SparseIndexVector_Uint8Vector;
2163 std::vector<uint8_t> uint8_vector(vector_segments.begin(),
2164 vector_segments.end());
2165 array_segments = tflite::CreateUint8Vector(
2166 builder_, builder_.CreateVector(uint8_vector))
2167 .Union();
2168 } else if (max_of_segments <= UINT16_MAX) {
2169 segments_type = tflite::SparseIndexVector_Uint16Vector;
2170 std::vector<uint16_t> uint16_vector(vector_segments.begin(),
2171 vector_segments.end());
2172 array_segments = tflite::CreateUint16Vector(
2173 builder_, builder_.CreateVector(uint16_vector))
2174 .Union();
2175 } else {
2176 segments_type = tflite::SparseIndexVector_Int32Vector;
2177 array_segments = tflite::CreateInt32Vector(
2178 builder_, builder_.CreateVector(vector_segments))
2179 .Union();
2180 }
2181
2182 auto indices = dim_metadata.getIndices();
2183 std::vector<int> vector_indices(indices.size(), 0);
2184 int max_of_indices = 0;
2185 for (int j = 0, end = indices.size(); j < end; j++) {
2186 vector_indices[j] = indices[j];
2187 if (vector_indices[j] > max_of_indices) {
2188 max_of_indices = vector_indices[j];
2189 }
2190 }
2191 tflite::SparseIndexVector indices_type;
2192 BufferOffset<void> array_indices;
2193 if (max_of_indices <= UINT8_MAX) {
2194 indices_type = tflite::SparseIndexVector_Uint8Vector;
2195 std::vector<uint8_t> uint8_vector(vector_indices.begin(),
2196 vector_indices.end());
2197 array_indices = tflite::CreateUint8Vector(
2198 builder_, builder_.CreateVector(uint8_vector))
2199 .Union();
2200 } else if (max_of_indices <= UINT16_MAX) {
2201 indices_type = tflite::SparseIndexVector_Uint16Vector;
2202 std::vector<uint16_t> uint16_vector(vector_indices.begin(),
2203 vector_indices.end());
2204 array_indices = tflite::CreateUint16Vector(
2205 builder_, builder_.CreateVector(uint16_vector))
2206 .Union();
2207 } else {
2208 indices_type = tflite::SparseIndexVector_Int32Vector;
2209 array_indices = tflite::CreateInt32Vector(
2210 builder_, builder_.CreateVector(vector_indices))
2211 .Union();
2212 }
2213
2214 fb_dim_metadata[i] = tflite::CreateDimensionMetadata(
2215 builder_, tflite::DimensionType_SPARSE_CSR, 0, segments_type,
2216 array_segments, indices_type, array_indices);
2217 }
2218 }
2219
2220 std::vector<int> traversal_order(dim_size);
2221 for (int i = 0; i < dim_size; i++) {
2222 traversal_order[i] = s_attr.getTraversalOrder()[i];
2223 }
2224 const int block_map_size = s_attr.getBlockMap().size();
2225 std::vector<int> block_map(block_map_size);
2226 for (int i = 0; i < block_map_size; i++) {
2227 block_map[i] = s_attr.getBlockMap()[i];
2228 }
2229
2230 return tflite::CreateSparsityParameters(
2231 builder_, builder_.CreateVector(traversal_order),
2232 builder_.CreateVector(block_map), builder_.CreateVector(fb_dim_metadata));
2233}
2234
// Collects the control edges implied by ControlNodeOps in `block` (as pairs
// of operation offsets within the block), then unwraps each ControlNodeOp by
// replacing it with a clone of its single wrapped operation. The block is
// mutated in place; the returned pairs are (source offset, target offset).
std::vector<std::pair<int, int>> Translator::ExtractControlEdges(
    mlir::Block* block) {
  std::vector<std::pair<int, int>> control_edges;

  mlir::IRRewriter rewriter(block->getParentOp()->getContext());

  // Since we're modifying *block, we store integer offsets to block->begin().
  llvm::DenseMap<Operation*, int> control_nodes_at;
  std::vector<Operation*> control_nodes;
  for (const auto& item : llvm::enumerate(*block)) {
    if (llvm::isa<mlir::TFL::ControlNodeOp>(item.value())) {
      control_nodes.push_back(&item.value());
      control_nodes_at.try_emplace(&item.value(), item.index());
    }
  }

  for (auto outer_op : control_nodes) {
    auto control_node_op = dyn_cast<mlir::TFL::ControlNodeOp>(outer_op);
    auto* inner_op = &control_node_op.body().front().front();
    auto control_token = control_node_op.control();

    // Now go through all uses. Since *block is in executable order, control
    // edges always point to operations we haven't modified yet.
    for (auto& use : control_token.getUses()) {
      auto owner = use.getOwner();
      // Control tokens can only be consumed by other ControlNodeOps,
      assert(llvm::isa<mlir::TFL::ControlNodeOp>(owner));
      assert(control_nodes_at.find(owner) != control_nodes_at.end());
      // Control edge in terms of offsets.
      control_edges.emplace_back(control_nodes_at[outer_op],
                                 control_nodes_at[owner]);
    }
    control_token.dropAllUses();

    // Replace the ControlNodeOp with the wrapped operation.
    rewriter.setInsertionPointAfter(outer_op);
    auto* cloned_inner = rewriter.clone(*inner_op);
    // Reroute all data results of the wrapper to the clone's results before
    // erasing the wrapper.
    for (auto it :
         llvm::zip(control_node_op.outputs(), cloned_inner->getResults())) {
      std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
    }
    rewriter.eraseOp(outer_op);
  }
  return control_edges;
}
2280
2281} // namespace
2282
2283namespace tflite {
2284
2285bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module,
2286 const FlatbufferExportOptions& options,
2287 std::string* serialized_flatbuffer) {
2288 auto maybe_translated = Translator::Translate(
2289 module, options.toco_flags, options.saved_model_tags,
2290 options.op_or_arg_name_mapper, options.metadata);
2291 if (!maybe_translated) return false;
2292 *serialized_flatbuffer = std::move(*maybe_translated);
2293 return true;
2294}
2295
2296} // namespace tflite
2297