1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" |
17 | |
18 | #include <stddef.h> |
19 | #include <stdlib.h> |
20 | |
21 | #include <algorithm> |
22 | #include <cstdint> |
23 | #include <memory> |
24 | #include <string> |
25 | #include <utility> |
26 | #include <vector> |
27 | |
28 | #include "absl/base/attributes.h" |
29 | #include "absl/container/flat_hash_map.h" |
30 | #include "absl/container/flat_hash_set.h" |
31 | #include "absl/strings/match.h" |
32 | #include "absl/strings/str_cat.h" |
33 | #include "absl/strings/str_format.h" |
34 | #include "absl/strings/str_join.h" |
35 | #include "absl/strings/string_view.h" |
36 | #include "flatbuffers/flatbuffers.h" // from @flatbuffers |
37 | #include "flatbuffers/flexbuffers.h" // from @flatbuffers |
38 | #include "llvm/ADT/ArrayRef.h" |
39 | #include "llvm/ADT/DenseMap.h" |
40 | #include "llvm/ADT/None.h" |
41 | #include "llvm/ADT/Optional.h" |
42 | #include "llvm/ADT/STLExtras.h" |
43 | #include "llvm/ADT/StringRef.h" |
44 | #include "llvm/Support/Casting.h" |
45 | #include "llvm/Support/CommandLine.h" |
46 | #include "llvm/Support/FormatVariadic.h" |
47 | #include "llvm/Support/ToolOutputFile.h" |
48 | #include "llvm/Support/raw_ostream.h" |
49 | #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project |
50 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
51 | #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project |
52 | #include "mlir/IR/Attributes.h" // from @llvm-project |
53 | #include "mlir/IR/Builders.h" // from @llvm-project |
54 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
55 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
56 | #include "mlir/IR/Location.h" // from @llvm-project |
57 | #include "mlir/IR/MLIRContext.h" // from @llvm-project |
58 | #include "mlir/IR/Operation.h" // from @llvm-project |
59 | #include "mlir/IR/Types.h" // from @llvm-project |
60 | #include "mlir/IR/Value.h" // from @llvm-project |
61 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
62 | #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project |
63 | #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" |
64 | #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" |
65 | #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" |
66 | #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" |
67 | #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" |
68 | #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" |
69 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" |
70 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
71 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" |
72 | #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" |
73 | #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" |
74 | #include "tensorflow/compiler/xla/statusor.h" |
75 | #include "tensorflow/core/framework/attr_value.pb.h" |
76 | #include "tensorflow/core/framework/node_def.pb.h" |
77 | #include "tensorflow/core/platform/errors.h" |
78 | #include "tensorflow/core/platform/logging.h" |
79 | #include "tensorflow/core/platform/status.h" |
80 | #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h" |
81 | #include "tensorflow/lite/experimental/remat/metadata_util.h" |
82 | #include "tensorflow/lite/kernels/internal/kernel_utils.h" |
83 | #include "tensorflow/lite/schema/schema_conversion_utils.h" |
84 | #include "tensorflow/lite/schema/schema_generated.h" |
85 | #include "tensorflow/lite/string_util.h" |
86 | #include "tensorflow/lite/tools/versioning/gpu_compatibility.h" |
87 | #include "tensorflow/lite/tools/versioning/op_version.h" |
88 | #include "tensorflow/lite/tools/versioning/runtime_version.h" |
89 | #include "tensorflow/lite/version.h" |
90 | |
91 | using llvm::dyn_cast; |
92 | using llvm::formatv; |
93 | using llvm::isa; |
94 | using llvm::Optional; |
95 | using llvm::StringRef; |
96 | using llvm::Twine; |
97 | using mlir::Dialect; |
98 | using mlir::ElementsAttr; |
99 | using mlir::MLIRContext; |
100 | using mlir::ModuleOp; |
101 | using mlir::NoneType; |
102 | using mlir::Operation; |
103 | using mlir::Region; |
104 | using mlir::StringAttr; |
105 | using mlir::TensorType; |
106 | using mlir::Type; |
107 | using mlir::UnknownLoc; |
108 | using mlir::Value; |
109 | using mlir::WalkResult; |
110 | using mlir::func::FuncOp; |
111 | using tensorflow::OpOrArgLocNameMapper; |
112 | using tensorflow::OpOrArgNameMapper; |
113 | using tensorflow::Status; |
114 | using tflite::flex::IsAllowlistedFlexOp; |
115 | using xla::StatusOr; |
116 | |
// Shorthand for a flatbuffer offset to a serialized table of type T.
template <typename T>
using BufferOffset = flatbuffers::Offset<T>;

// Shorthand for a flatbuffer offset to a serialized vector of T.
template <typename T>
using VectorBufferOffset = flatbuffers::Offset<flatbuffers::Vector<T>>;

// Serialized custom-op options are stored as a raw byte vector.
using CustomOptionsOffset = VectorBufferOffset<uint8_t>;

namespace error = tensorflow::error;
namespace tfl = mlir::TFL;

// Prefix prepended to TF op names when they are exported as Flex
// (select TensorFlow) custom ops.
ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";

// Use initial buffer size in flatbuffer builder to be same as the initial size
// used by the TOCO export. (It does not explain rationale for this choice.)
constexpr size_t kInitialBufferSize = 10240;
133 | |
134 | // Set `isSigned` to false if the `type` is an 8-bit unsigned integer type. |
135 | // Since tflite doesn't support unsigned for other types, returns error if |
136 | // `isSigned` is set to false for other types. |
137 | static StatusOr<tflite::TensorType> GetTFLiteType(Type type, |
138 | bool is_signed = true) { |
139 | if (!is_signed && type.isSignlessInteger(8)) { |
140 | return tflite::TensorType_UINT8; |
141 | } |
142 | if (!is_signed) { |
143 | return Status(error::INVALID_ARGUMENT, |
144 | "'isSigned' can only be set for 8-bits integer type" ); |
145 | } |
146 | |
147 | if (type.isF32()) { |
148 | return tflite::TensorType_FLOAT32; |
149 | } else if (type.isF16()) { |
150 | return tflite::TensorType_FLOAT16; |
151 | } else if (type.isF64()) { |
152 | return tflite::TensorType_FLOAT64; |
153 | } else if (type.isa<mlir::TF::StringType>()) { |
154 | return tflite::TensorType_STRING; |
155 | } else if (type.isa<mlir::TF::Quint8Type>()) { |
156 | return tflite::TensorType_UINT8; |
157 | } else if (auto complex_type = type.dyn_cast<mlir::ComplexType>()) { |
158 | auto ftype = complex_type.getElementType(); |
159 | if (ftype.isF32()) { |
160 | return tflite::TensorType_COMPLEX64; |
161 | } |
162 | if (ftype.isF64()) { |
163 | return tflite::TensorType_COMPLEX128; |
164 | } |
165 | return Status(error::INVALID_ARGUMENT, "Unsupported type" ); |
166 | } else if (auto itype = type.dyn_cast<mlir::IntegerType>()) { |
167 | switch (itype.getWidth()) { |
168 | case 1: |
169 | return tflite::TensorType_BOOL; |
170 | case 8: |
171 | return itype.isUnsigned() ? tflite::TensorType_UINT8 |
172 | : tflite::TensorType_INT8; |
173 | case 16: |
174 | return itype.isUnsigned() ? tflite::TensorType_UINT16 |
175 | : tflite::TensorType_INT16; |
176 | case 32: |
177 | return itype.isUnsigned() ? tflite::TensorType_UINT32 |
178 | : tflite::TensorType_INT32; |
179 | case 64: |
180 | return itype.isUnsigned() ? tflite::TensorType_UINT64 |
181 | : tflite::TensorType_INT64; |
182 | } |
183 | } else if (auto q_uniform_type = |
184 | type.dyn_cast<mlir::quant::UniformQuantizedType>()) { |
185 | return GetTFLiteType(q_uniform_type.getStorageType(), |
186 | q_uniform_type.isSigned()); |
187 | } else if (auto q_peraxis_type = |
188 | type.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) { |
189 | return GetTFLiteType(q_peraxis_type.getStorageType(), |
190 | q_peraxis_type.isSigned()); |
191 | } else if (auto q_calibrated_type = |
192 | type.dyn_cast<mlir::quant::CalibratedQuantizedType>()) { |
193 | return GetTFLiteType(q_calibrated_type.getExpressedType()); |
194 | } else if (type.isa<mlir::TF::ResourceType>()) { |
195 | return tflite::TensorType_RESOURCE; |
196 | } else if (type.isa<mlir::TF::VariantType>()) { |
197 | return tflite::TensorType_VARIANT; |
198 | } |
199 | // TFLite export fills FLOAT32 for unknown data types. Returning an error |
200 | // for now for safety and this could be revisited when required. |
201 | return Status(error::INVALID_ARGUMENT, "Unsupported type" ); |
202 | } |
203 | |
204 | static bool IsConst(Operation* op) { |
205 | return isa<mlir::func::ConstantOp, mlir::arith::ConstantOp, mlir::TF::ConstOp, |
206 | tfl::ConstOp, tfl::QConstOp, tfl::SparseConstOp, |
207 | tfl::SparseQConstOp, mlir::TFL::NoValueOp>(op); |
208 | } |
209 | |
210 | static bool IsTFResourceOp(Operation* op) { |
211 | for (const auto& operand : op->getOperands()) { |
212 | auto elementType = getElementTypeOrSelf(operand.getType()); |
213 | if (elementType.isa<mlir::TF::ResourceType>()) { |
214 | return true; |
215 | } |
216 | } |
217 | for (const auto& result : op->getResults()) { |
218 | auto elementType = getElementTypeOrSelf(result.getType()); |
219 | if (elementType.isa<mlir::TF::ResourceType>()) { |
220 | return true; |
221 | } |
222 | } |
223 | return false; |
224 | } |
225 | |
226 | // Returns whether the current op is not supported by the TF Lite runtime. |
// Returns whether the current op is not supported by the TF Lite runtime,
// even when exported through the Flex (select TF ops) path.
static bool IsUnsupportedFlexOp(const std::string& op_name) {
  static constexpr const char* kUnsupportedOps[] = {"PartitionedCall",
                                                    "StatefulPartitionedCall"};
  for (const char* unsupported : kUnsupportedOps) {
    if (op_name == unsupported) return true;
  }
  return false;
}
230 | |
231 | // Create description of operation that could not be converted. |
232 | static std::string GetOpDescriptionForDebug(Operation* inst) { |
233 | const int kLargeElementsAttr = 16; |
234 | std::string op_str; |
235 | llvm::raw_string_ostream os(op_str); |
236 | inst->getName().print(os); |
237 | os << "(" ; |
238 | if (!inst->getOperandTypes().empty()) { |
239 | bool first = true; |
240 | for (Type operand_type : inst->getOperandTypes()) { |
241 | os << (!first ? ", " : "" ); |
242 | first = false; |
243 | os << operand_type; |
244 | } |
245 | } |
246 | os << ") -> (" ; |
247 | if (!inst->getResultTypes().empty()) { |
248 | bool first = true; |
249 | for (Type result_type : inst->getResultTypes()) { |
250 | os << (!first ? ", " : "" ); |
251 | first = false; |
252 | os << result_type; |
253 | } |
254 | } |
255 | os << ")" ; |
256 | // Print out attributes except for large elementsattributes (which should |
257 | // rarely be the cause why the legalization didn't happen). |
258 | if (!inst->getAttrDictionary().empty()) { |
259 | os << " : {" ; |
260 | bool first = true; |
261 | for (auto& named_attr : inst->getAttrDictionary()) { |
262 | os << (!first ? ", " : "" ); |
263 | first = false; |
264 | os << named_attr.getName().getValue() << " = " ; |
265 | if (auto element_attr = named_attr.getValue().dyn_cast<ElementsAttr>()) { |
266 | if (element_attr.getNumElements() <= kLargeElementsAttr) { |
267 | element_attr.print(os); |
268 | } else { |
269 | os << "<large>" ; |
270 | } |
271 | } else { |
272 | named_attr.getValue().print(os); |
273 | } |
274 | } |
275 | os << "}" ; |
276 | } |
277 | return os.str(); |
278 | } |
279 | |
280 | // Create a summary with the given information regarding op names and |
281 | // descriptions. |
// Create a summary with the given information regarding op names and
// descriptions: one line listing the op names, then one detail per line.
static std::string GetOpsSummary(
    const std::map<std::string, std::set<std::string>>& ops,
    const std::string& summary_title) {
  std::string op_names;
  std::string op_details;
  bool first_name = true;
  bool first_detail = true;
  // Flatten the map: keys become the op-name list, all per-op details are
  // appended (in key order) to the details list.
  for (const auto& [op_name, details] : ops) {
    if (!first_name) op_names += ", ";
    first_name = false;
    op_names += op_name;
    for (const auto& detail : details) {
      if (!first_detail) op_details += "\n\t";
      first_detail = false;
      op_details += detail;
    }
  }

  std::string summary = summary_title;
  summary += " ops: ";
  summary += op_names;
  summary += "\n";
  summary += "Details:\n\t";
  summary += op_details;
  return summary;
}
306 | |
307 | template <typename T> |
308 | static bool HasValidTFLiteType(Value value, T& error_handler) { |
309 | // None type is allowed to represent unspecified operands. |
310 | if (value.getType().isa<NoneType>()) return true; |
311 | |
312 | auto type = value.getType().dyn_cast<TensorType>(); |
313 | if (!type) { |
314 | if (auto op = value.getDefiningOp()) { |
315 | error_handler.emitError() |
316 | << '\'' << op << "' should produce value of tensor type instead of " |
317 | << value.getType(); |
318 | return false; |
319 | } |
320 | error_handler.emitError("expected tensor type, got " ) << value.getType(); |
321 | return false; |
322 | } |
323 | |
324 | Type element_type = type.getElementType(); |
325 | auto status = GetTFLiteType(element_type); |
326 | if (!status.ok()) { |
327 | return error_handler.emitError( |
328 | formatv("Failed to convert element type '{0}': {1}" , |
329 | element_type, status.status().error_message())), |
330 | false; |
331 | } |
332 | return true; |
333 | } |
334 | |
335 | // Returns true if the module holds all the invariants expected by the |
336 | // Translator class. |
337 | // TODO(hinsu): Now that translation is done by making a single pass over the |
338 | // MLIR module, consider inlining these validation checks at the place where |
339 | // these invariants are assumed instead of checking upfront. |
// Returns true if the module holds all the invariants expected by the
// Translator class: at least one entry function, single-block functions, and
// TFLite-compatible types on all arguments and results.
// TODO(hinsu): Now that translation is done by making a single pass over the
// MLIR module, consider inlining these validation checks at the place where
// these invariants are assumed instead of checking upfront.
static bool IsValidTFLiteMlirModule(ModuleOp module) {
  MLIRContext* context = module.getContext();

  // Verify that module has a function named main.
  FuncOp main_fn = module.lookupSymbol<FuncOp>("main");
  if (!main_fn) {
    // No "main": accept the module if some function is marked as an entry
    // point via a non-empty tf.entry_function attribute.
    int entry_func_count = 0;
    for (auto fn : module.getOps<FuncOp>()) {
      auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
      if (attrs && !attrs.empty()) {
        ++entry_func_count;
      }
    }

    // Verify that module has at least one entry function.
    if (entry_func_count == 0) {
      return emitError(UnknownLoc::get(context),
                       "should have a least one entry function"),
             false;
    }
  }

  for (auto fn : module.getOps<FuncOp>()) {
    // Each function must consist of a single basic block.
    if (!llvm::hasSingleElement(fn)) {
      return fn.emitError("should have exactly one basic block"), false;
    }
    auto& bb = fn.front();

    // Function arguments must have TFLite-compatible tensor types; variant
    // types get a dedicated error message since they are a common failure.
    for (auto arg : bb.getArguments()) {
      if (!HasValidTFLiteType(arg, fn)) {
        auto elementType = getElementTypeOrSelf(arg.getType());
        if (elementType.isa<mlir::TF::VariantType>()) {
          return fn.emitError(
                     "function argument uses variant type. Currently, the "
                     "variant type is not natively supported in TFLite. Please "
                     "consider not using the variant type: ")
                     << arg.getType(),
                 false;
        }
        return fn.emitError("invalid TFLite type: ") << arg.getType(), false;
      }
    }

    // Verify that all operations except the terminator have exactly one
    // result of type supported by TFLite (or is a ControlType, which
    // will be removed later by ExtractControlEdges.)
    for (auto& inst : bb) {
      if (inst.hasTrait<mlir::OpTrait::IsTerminator>()) break;

      for (auto result : inst.getResults()) {
        if (result.getType().isa<mlir::TFL::ControlType>()) continue;
        if (!HasValidTFLiteType(result, inst)) {
          auto elementType = getElementTypeOrSelf(result.getType());
          if (elementType.isa<mlir::TF::VariantType>()) {
            return inst.emitError(
                       "operand result uses variant type. Currently, the "
                       "variant type is not natively supported in TFLite. "
                       "Please "
                       "consider not using the variant type: ")
                       << result.getType(),
                   false;
          }
          return fn.emitError("invalid TFLite type: ") << result.getType(),
                 false;
        }
      }
    }
  }

  return true;
}
411 | |
412 | static std::unique_ptr<::tensorflow::NodeDef> GetTensorFlowNodeDef( |
413 | ::mlir::Operation* inst) { |
414 | // We pass empty string for the original node_def name since Flex runtime |
415 | // does not care about this being set correctly on node_def. There is no |
416 | // "easy" (see b/120948529) way yet to get this from MLIR inst. |
417 | auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef( |
418 | inst, /*name=*/"" , /*ignore_unregistered_attrs=*/true); |
419 | if (!status_or_node_def.ok()) { |
420 | inst->emitOpError( |
421 | Twine("failed to obtain TensorFlow nodedef with status: " + |
422 | status_or_node_def.status().ToString())); |
423 | return {}; |
424 | } |
425 | return std::move(status_or_node_def.value()); |
426 | } |
427 | |
428 | // Converts a mlir padding StringRef to TfLitePadding. |
429 | // Returns llvm::None if conversion fails. |
430 | static Optional<TfLitePadding> GetTflitePadding(Operation* inst, |
431 | llvm::StringRef padding) { |
432 | const tflite::Padding padding_attr = |
433 | std::move(llvm::StringSwitch<tflite::Padding>(padding) |
434 | .Case("SAME" , tflite::Padding_SAME) |
435 | .Case("VALID" , tflite::Padding_VALID)); |
436 | if (padding_attr == tflite::Padding_SAME) { |
437 | return kTfLitePaddingSame; |
438 | } |
439 | if (padding_attr == tflite::Padding_VALID) { |
440 | return kTfLitePaddingValid; |
441 | } |
442 | |
443 | return inst->emitOpError() << "Invalid padding attribute: " << padding, |
444 | llvm::None; |
445 | } |
446 | |
447 | // Extracts TfLitePoolParams from a TFL custom op. |
448 | // Template parameter, TFLOp, should be a TFL custom op containing attributes |
449 | // generated from TfLitePoolParams. |
450 | // Returns llvm::None if conversion fails. |
451 | template <typename TFLOp> |
452 | static Optional<TfLitePoolParams> GetTflitePoolParams(Operation* inst, |
453 | TFLOp op) { |
454 | TfLitePoolParams pool_params; |
455 | pool_params.stride_height = op.stride_h().getSExtValue(); |
456 | pool_params.stride_width = op.stride_w().getSExtValue(); |
457 | pool_params.filter_height = op.filter_h().getSExtValue(); |
458 | pool_params.filter_width = op.filter_w().getSExtValue(); |
459 | const auto padding = GetTflitePadding(inst, op.padding()); |
460 | if (padding) { |
461 | pool_params.padding = *padding; |
462 | pool_params.activation = kTfLiteActNone; |
463 | pool_params.computed.padding = TfLitePaddingValues{0, 0, 0, 0}; |
464 | return pool_params; |
465 | } |
466 | |
467 | return llvm::None; |
468 | } |
469 | |
470 | namespace { |
471 | |
472 | // Helper struct that wraps inputs/outputs of a single SignatureDef. |
// Helper struct that wraps inputs/outputs of a single SignatureDef.
struct SignatureDefData {
  // Note, we are using maps here to make order deterministic
  // for easily testing only.

  // Inputs defined in the signature def mapped to tensor names.
  std::map<std::string, std::string> inputs;
  // Outputs defined in the signature def mapped to tensor names.
  std::map<std::string, std::string> outputs;
  // Signature key (name under which this signature is exported).
  std::string signature_key;
  // Index of the subgraph (in the FlatBuffer model) this signature refers to.
  uint32_t subgraph_index;
};
486 | |
487 | // Translates an MLIR module in TFLite dialect to TFLite FlatBuffer. |
// Translates an MLIR module in TFLite dialect to TFLite FlatBuffer.
class Translator {
 public:
  // Translates the given MLIR module into TFLite FlatBuffer format and returns
  // the serialized output. Returns llvm::None on unsupported, invalid inputs or
  // internal error.
  static Optional<std::string> Translate(
      ModuleOp module, const toco::TocoFlags& toco_flags,
      const std::unordered_set<std::string>& tags,
      OpOrArgNameMapper* op_or_arg_name_mapper,
      const std::map<std::string, std::string>& metadata);

 private:
  // The kinds of op conversions this translator may emit: TFLite builtin ops,
  // select (Flex) TF ops, and custom ops.
  enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp };

  // Initializes conversion state from the toco flags: which op types are
  // enabled, user-selected Flex ops, supported backends, and metadata.
  explicit Translator(ModuleOp module, const toco::TocoFlags& toco_flags,
                      const std::unordered_set<std::string>& saved_model_tags,
                      OpOrArgNameMapper* op_or_arg_name_mapper,
                      const std::map<std::string, std::string>& metadata)
      : module_(module),
        name_mapper_(*op_or_arg_name_mapper),
        builder_(kInitialBufferSize),
        saved_model_tags_(saved_model_tags),
        allow_all_select_tf_ops_(toco_flags.allow_all_select_tf_ops()),
        select_user_tf_ops_(toco_flags.select_user_tf_ops().begin(),
                            toco_flags.select_user_tf_ops().end()),
        metadata_(metadata),
        supported_backends_(toco_flags.supported_backends().begin(),
                            toco_flags.supported_backends().end()) {
    // The first buffer must be empty according to the schema definition.
    empty_buffer_ = tflite::CreateBuffer(builder_);
    buffers_.push_back(empty_buffer_);
    // Builtins are enabled unless everything is forced through select TF ops.
    if (!toco_flags.force_select_tf_ops()) {
      enabled_op_types_.emplace(OpType::kTfliteBuiltin);
    }
    if (toco_flags.enable_select_tf_ops()) {
      enabled_op_types_.emplace(OpType::kSelectTf);
    }
    if (toco_flags.allow_custom_ops()) {
      enabled_op_types_.emplace(OpType::kCustomOp);
    }
    tf_dialect_ =
        module.getContext()->getOrLoadDialect<mlir::TF::TensorFlowDialect>();
    tfl_dialect_ = module.getContext()
                       ->getOrLoadDialect<mlir::TFL::TensorFlowLiteDialect>();
    // Right now the TF executor dialect is still needed to build NodeDef.
    module.getContext()
        ->getOrLoadDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
  }

  // Performs the actual module-to-flatbuffer translation; called by
  // Translate() after construction/validation.
  Optional<std::string> TranslateInternal();

  // Returns TFLite buffer populated with constant value if the operation is
  // TFLite constant operation. Otherwise, returns an empty buffer. Emits error
  // and returns llvm::None on failure.
  Optional<BufferOffset<tflite::Buffer>> BuildBuffer(Operation* inst);

  // Build TFLite tensor from the given type. This function is for tfl.lstm
  // intermediates, which should have UniformQuantizedType.
  Optional<BufferOffset<tflite::Tensor>> BuildTensorFromType(
      mlir::Type type, const std::string& name);

  // Builds TF::VariantType from the given element type. Returns llvm::None if
  // failure. Returns empty vector if the element type is not TF::VariantType or
  // there is empty TensorType in the TF::VariantType.
  Optional<std::vector<BufferOffset<tflite::VariantSubType>>>
  BuildTFVariantType(mlir::Type element_type);

  // Builds TFLite tensor from the given value. `buffer_idx` is index of the
  // corresponding buffer. Emits error and returns llvm::None on failure.
  Optional<BufferOffset<tflite::Tensor>> BuildTensor(
      Value value, const std::string& name, unsigned buffer_idx,
      const Optional<BufferOffset<tflite::QuantizationParameters>>&
          quant_parameters);

  // TODO(b/137395003): Legalize tf.IfOp to TFLite dialect, and change the
  // following method to handle TFL::IfOp.
  BufferOffset<tflite::Operator> BuildIfOperator(
      mlir::TF::IfOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Build while operator where cond & body are regions.
  Optional<BufferOffset<tflite::Operator>> BuildWhileOperator(
      mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Build call once operator.
  BufferOffset<tflite::Operator> BuildCallOnceOperator(
      mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Builds the NumericVerify custom operator used for quantization debugging.
  BufferOffset<tflite::Operator> BuildNumericVerifyOperator(
      mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Builds a custom operator from a TFL::CustomOp, serializing its options.
  BufferOffset<tflite::Operator> BuildCustomOperator(
      Operation* inst, mlir::TFL::CustomOp op,
      const std::vector<int32_t>& operands,
      const std::vector<int32_t>& results);

  // Serializes `node_def` into the custom-options byte vector used by Flex
  // (select TF) ops. Returns llvm::None on failure.
  Optional<CustomOptionsOffset> CreateFlexOpCustomOptions(
      const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);

  // Serializes `node_def` attributes into custom options for custom ops.
  // Returns llvm::None on failure.
  Optional<CustomOptionsOffset> CreateCustomOpCustomOptions(
      const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);

  // Builds a flexbuffer containing the attributes of `node_def`.
  std::unique_ptr<flexbuffers::Builder> CreateFlexBuilderWithNodeAttrs(
      const ::tensorflow::NodeDef& node_def, const mlir::Location& loc);

  // Returns opcode index for op identified by the op_name, if already
  // available. Otherwise, creates a new OperatorCode using the given `builtin`
  // operator and associates it with `op_name`.
  uint32_t GetOpcodeIndex(const std::string& op_name,
                          tflite::BuiltinOperator builtin);

  // Builds operator for the given operation with specified operand and result
  // tensor indices. Emits an error and returns llvm::None on failure.
  Optional<BufferOffset<tflite::Operator>> BuildOperator(
      Operation* inst, std::vector<int32_t> operands,
      const std::vector<int32_t>& results,
      const std::vector<int32_t>& intermediates);

  // Returns the quantization parameters for output value of "quant.stats" op.
  BufferOffset<tflite::QuantizationParameters>
  GetQuantizationForQuantStatsOpOutput(mlir::quantfork::StatisticsOp stats_op);

  // Build a subgraph with a given name out of the region either corresponding
  // to a function's body or while op. Modifies *region by calling
  // ExtractControlEdges.
  Optional<BufferOffset<tflite::SubGraph>> BuildSubGraph(
      const std::string& name, Region* region, const int index);

  // Modifies *block by unwrapping all ControlNodeOps. The DAG of the control
  // dependencies is returned as a vector of its edges, with node indices into
  // *block.
  std::vector<std::pair<int, int>> ExtractControlEdges(mlir::Block* block);

  // Builds Metadata with the given `name` and buffer `content`.
  BufferOffset<tflite::Metadata> BuildMetadata(StringRef name,
                                               StringRef content);

  // Encodes the `tfl.metadata` dictionary attribute of the module to the
  // metadata section in the final model.
  Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
  CreateMetadataVector();

  // Builds and returns list of tfl.SignatureDef sections in the model.
  Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>>
  CreateSignatureDefs(const std::vector<SignatureDefData>& signature_defs);

  // Returns list of offsets for the passed 'items' in TensorMap structure
  // inside the flatbuffer.
  // 'items' is a map from tensor name in signatureDef to tensor name in
  // the subgraph, specified by the 'subgraph_index' argument.
  std::vector<BufferOffset<tflite::TensorMap>> GetList(
      const int subgraph_index,
      const std::map<std::string, std::string>& items);

  // Uses the tf.entry_function attribute (if set) to initialize the op to name
  // mapping.
  void InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr);

  // Determines if the specified operation op's operand at operand_index
  // is marked as a stateful operand.
  bool IsStatefulOperand(mlir::Operation* op, int operand_index);

  // Returns a unique name for `val`.
  std::string UniqueName(mlir::Value val);

  // Serializes sparse-tensor metadata for sparse constants.
  BufferOffset<tflite::SparsityParameters> BuildSparsityParameters(
      const mlir::TFL::SparsityParameterAttr& s_attr);

  // Sums MAC counts over all ops; returns false if any op's count is unknown.
  bool EstimateArithmeticCount(int64_t* count);

  // Check compatibility with GPU delegate and returns the compatibility.
  bool CheckGpuDelegateCompatibility(uint8_t* model_buffer_pointer);

  // The module being translated.
  ModuleOp module_;

  // Maps values/ops to unique names; not owned.
  tensorflow::OpOrArgNameMapper& name_mapper_;

  // Builder that accumulates the serialized flatbuffer.
  flatbuffers::FlatBufferBuilder builder_;
  // Schema requires the first buffer to be empty; cached here for reuse.
  BufferOffset<tflite::Buffer> empty_buffer_;

  // All constant/weight buffers emitted so far (index 0 is empty_buffer_).
  std::vector<BufferOffset<tflite::Buffer>> buffers_;
  // Maps subgraph index and tensor name in the graph to the tensor index.
  absl::flat_hash_map<int, absl::flat_hash_map<std::string, int>>
      tensor_index_map_;

  // Maps op name to index of the corresponding OperatorCode in opcodes_ vector.
  absl::flat_hash_map<std::string, uint32_t> opcode_index_map_;
  std::vector<BufferOffset<tflite::OperatorCode>> opcodes_;

  // Maps function name to index of the corresponding subgraph in the FlatBuffer
  // model.
  absl::flat_hash_map<std::string, int> subgraph_index_map_;
  absl::flat_hash_set<OpType> enabled_op_types_;

  // Points to TensorFlow and TFLite dialects, respectively. nullptr if the
  // dialect is not registered.
  const Dialect* tf_dialect_;
  const Dialect* tfl_dialect_;

  // The failed ops during legalization.
  std::map<std::string, std::set<std::string>> failed_flex_ops_;
  std::map<std::string, std::set<std::string>> failed_custom_ops_;

  // Ops to provide warning messages.
  std::map<std::string, std::set<std::string>> custom_ops_;
  std::map<std::string, std::set<std::string>> flex_ops_;

  // Resource ops to provide warning messages.
  std::map<std::string, std::set<std::string>> resource_ops_;

  // Set of saved model tags, if any.
  const std::unordered_set<std::string> saved_model_tags_;
  // Allows automatic pass through of TF ops as select Tensorflow ops.
  const bool allow_all_select_tf_ops_;
  // User's defined ops allowed with Flex.
  const std::unordered_set<std::string> select_user_tf_ops_;
  // Map of key value pairs of metadata to export.
  const std::map<std::string, std::string> metadata_;
  // User's defined supported backends.
  const std::unordered_set<std::string> supported_backends_;
  // A mapping table to mlir::Operation objects for TFL subgraph and operator
  // index in a flatbuffer.
  std::vector<std::vector<Operation*>> subgraph_op_inst_map_;

  // Will be populated by ExtractControlEdges to contain the control
  // dependencies contained in the ControlNodeOps. Will then be used to populate
  // metadata in the exported flatbuffer file.
  tflite::ModelControlDependencies model_control_dependencies_;
};
719 | |
720 | bool Translator::EstimateArithmeticCount(int64_t* count) { |
721 | int64_t result = 0; |
722 | bool encounter_undetermined_mac = false; |
723 | module_->walk([&](mlir::TFL::TflArithmeticCountOpInterface op) { |
724 | int64_t mac_count = op.GetArithmeticCount(op); |
725 | if (mac_count < 0) { |
726 | encounter_undetermined_mac = true; |
727 | return; |
728 | } |
729 | result += mac_count; |
730 | }); |
731 | |
732 | *count = result; |
733 | return !encounter_undetermined_mac; |
734 | } |
735 | |
736 | std::string Translator::UniqueName(mlir::Value val) { |
737 | return std::string(name_mapper_.GetUniqueName(val)); |
738 | } |
739 | |
Optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
    Operation* inst) {
  // Serializes the constant payload of `inst` into a flatbuffer Buffer.
  // Non-constant ops map to the shared empty buffer; attribute-to-tensor
  // conversion failure yields llvm::None.
  ElementsAttr attr;
  if (auto cst = dyn_cast<mlir::arith::ConstantOp>(inst)) {
    // arith::ConstantOp have ElementAttr at this point due to validation of the
    // TFLite module.
    attr = cst.getValue().cast<ElementsAttr>();
  } else if (auto cst = dyn_cast<mlir::TF::ConstOp>(inst)) {
    attr = cst.value();
  } else if (auto cst = dyn_cast<tfl::ConstOp>(inst)) {
    attr = cst.value();
  } else if (auto cst = dyn_cast<tfl::QConstOp>(inst)) {
    attr = cst.value();
  } else if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
    // Sparse constants serialize their compressed payload, not dense values.
    attr = cst.compressed_data();
  } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
    attr = cst.compressed_data();
  } else {
    // Not a recognized constant op: no data to emit.
    return empty_buffer_;
  }

  tensorflow::Tensor tensor;
  auto status = tensorflow::ConvertToTensor(attr, &tensor);
  if (!status.ok()) {
    inst->emitError(
        Twine("failed to convert value attribute to tensor with error: " +
              status.ToString()));
    return llvm::None;
  }

  // TensorFlow and TensorFlow Lite use different string encoding formats.
  // Convert to TensorFlow Lite format if it's a constant string tensor.
  if (tensor.dtype() == tensorflow::DT_STRING) {
    ::tflite::DynamicBuffer dynamic_buffer;
    auto flat = tensor.flat<::tensorflow::tstring>();
    for (int i = 0; i < flat.size(); ++i) {
      const auto& str = flat(i);
      dynamic_buffer.AddString(str.c_str(), str.length());
    }
    char* tensor_buffer;
    int bytes = dynamic_buffer.WriteToBuffer(&tensor_buffer);
    auto buffer_data =
        builder_.CreateVector(reinterpret_cast<uint8_t*>(tensor_buffer), bytes);
    // CreateVector copied the bytes; release the staging buffer allocated by
    // WriteToBuffer.
    free(tensor_buffer);
    return tflite::CreateBuffer(builder_, buffer_data);
  }

  // Non-string tensors: copy the raw tensor bytes directly.
  absl::string_view tensor_data = tensor.tensor_data();
  auto buffer_data = builder_.CreateVector(
      reinterpret_cast<const uint8_t*>(tensor_data.data()), tensor_data.size());
  return tflite::CreateBuffer(builder_, buffer_data);
}
792 | |
793 | Optional<std::vector<BufferOffset<tflite::VariantSubType>>> |
794 | Translator::BuildTFVariantType(mlir::Type element_type) { |
795 | std::vector<BufferOffset<tflite::VariantSubType>> variant_params; |
796 | auto variant_type = element_type.dyn_cast<mlir::TF::VariantType>(); |
797 | if (!variant_type) { |
798 | return variant_params; |
799 | } |
800 | |
801 | // We only support up to one nested type in tf_type.variant_type. |
802 | if (variant_type.getSubtypes().size() > 1) { |
803 | return llvm::None; |
804 | } |
805 | if (variant_type.getSubtypes().empty()) { |
806 | return variant_params; |
807 | } |
808 | mlir::TensorType tensor_type = variant_type.getSubtypes().front(); |
809 | tflite::TensorType tflite_element_type = |
810 | GetTFLiteType(tensor_type.getElementType()).value(); |
811 | std::vector<int32_t> shape; |
812 | if (tensor_type.hasRank()) { |
813 | llvm::ArrayRef<int64_t> shape_ref = tensor_type.getShape(); |
814 | shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end()); |
815 | } |
816 | |
817 | variant_params.push_back( |
818 | tflite::CreateVariantSubType(builder_, builder_.CreateVector(shape), |
819 | tflite_element_type, tensor_type.hasRank())); |
820 | return variant_params; |
821 | } |
822 | |
823 | Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensorFromType( |
824 | mlir::Type type, const std::string& name) { |
825 | auto tensor_type = type.cast<TensorType>(); |
826 | |
827 | llvm::ArrayRef<int64_t> shape_ref; |
828 | std::vector<int32_t> shape; |
829 | |
830 | if (tensor_type.hasRank()) { |
831 | if (tensor_type.hasStaticShape()) { |
832 | shape_ref = tensor_type.getShape(); |
833 | shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end()); |
834 | } else { |
835 | return llvm::None; |
836 | } |
837 | } |
838 | |
839 | auto element_type = tensor_type.getElementType(); |
840 | tflite::TensorType tflite_element_type = |
841 | GetTFLiteType(tensor_type.getElementType()).value(); |
842 | Optional<std::vector<BufferOffset<tflite::VariantSubType>>> variant_params = |
843 | BuildTFVariantType(element_type); |
844 | if (!variant_params.hasValue()) { |
845 | return llvm::None; |
846 | } |
847 | BufferOffset<tflite::QuantizationParameters> q_params = 0; |
848 | if (auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>()) { |
849 | std::vector<float> scales = {static_cast<float>(qtype.getScale())}; |
850 | std::vector<int64_t> zero_points = {qtype.getZeroPoint()}; |
851 | q_params = tflite::CreateQuantizationParameters( |
852 | builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>(scales), |
853 | builder_.CreateVector<int64_t>(zero_points)); |
854 | } else if (auto qtype = |
855 | element_type |
856 | .dyn_cast<mlir::quant::CalibratedQuantizedType>()) { |
857 | std::vector<float> mins = {static_cast<float>(qtype.getMin())}; |
858 | std::vector<float> maxs = {static_cast<float>(qtype.getMax())}; |
859 | q_params = tflite::CreateQuantizationParameters( |
860 | builder_, builder_.CreateVector<float>(mins), |
861 | builder_.CreateVector<float>(maxs)); |
862 | } |
863 | return tflite::CreateTensor( |
864 | builder_, builder_.CreateVector(shape), tflite_element_type, |
865 | /*buffer=*/0, builder_.CreateString(name), q_params, |
866 | /*is_variable=*/false, /*sparsity=*/0, /*shape_signature=*/0, |
867 | /*has_rank=*/tensor_type.hasRank(), |
868 | variant_params->empty() ? 0 : builder_.CreateVector(*variant_params)); |
869 | } |
870 | |
Optional<BufferOffset<tflite::Tensor>> Translator::BuildTensor(
    Value value, const std::string& name, unsigned buffer_idx,
    const Optional<BufferOffset<tflite::QuantizationParameters>>&
        quant_parameters) {
  // Serializes one SSA value as a flatbuffer Tensor: shape (plus a
  // shape_signature when dimensions are dynamic), element type, quantization,
  // sparsity, and variable-ness. Returns llvm::None when the shape or element
  // type cannot be encoded.
  auto type = value.getType().cast<TensorType>();

  // TFLite requires tensor shape only for the inputs and constants.
  // However, we output all known shapes for better round-tripping
  auto check_shape =
      [&](llvm::ArrayRef<int64_t> shape_ref) -> mlir::LogicalResult {
    // Flatbuffer shape dimensions are int32; reject anything that overflows.
    auto is_out_of_range = [](int64_t dim) {
      return dim > std::numeric_limits<int32_t>::max();
    };

    if (std::any_of(shape_ref.begin(), shape_ref.end(), is_out_of_range))
      return mlir::emitError(
          value.getLoc(),
          "result shape dimensions out of 32 bit int type range" );

    return mlir::success();
  };

  std::vector<int32_t> shape;
  std::vector<int32_t> shape_signature;
  auto* inst = value.getDefiningOp();
  if (type.hasStaticShape()) {
    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  } else if (inst && IsConst(inst)) {
    // Const op can have a result of dynamic shaped type (e.g. due to constant
    // folding), but we can still derive the shape of a constant tensor for
    // its attribute type.
    auto tensor_attr = inst->getAttr("value" ).cast<mlir::TypedAttr>();
    llvm::ArrayRef<int64_t> shape_ref =
        tensor_attr.getType().cast<TensorType>().getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  } else if (type.hasRank()) {
    llvm::ArrayRef<int64_t> shape_ref = type.getShape();
    if (mlir::failed(check_shape(shape_ref))) return llvm::None;

    shape.reserve(shape_ref.size());
    // Dynamic dimensions (-1) are written as 1 in `shape`; the original -1 is
    // preserved in `shape_signature` below.
    for (auto& dim : shape_ref) {
      shape.push_back(dim == -1 ? 1 : dim);
    }
    shape_signature = std::vector<int32_t>(shape_ref.begin(), shape_ref.end());
  }

  // Sparse constants carry their sparsity layout as an op attribute.
  BufferOffset<tflite::SparsityParameters> s_params = 0;
  if (auto* inst = value.getDefiningOp()) {
    if (auto cst = dyn_cast<tfl::SparseConstOp>(inst)) {
      s_params = BuildSparsityParameters(cst.s_param());
    } else if (auto cst = dyn_cast<tfl::SparseQConstOp>(inst)) {
      s_params = BuildSparsityParameters(cst.s_param());
    }
  }

  Type element_type = type.getElementType();
  tflite::TensorType tflite_element_type =
      GetTFLiteType(type.getElementType()).value();

  Optional<std::vector<BufferOffset<tflite::VariantSubType>>> variant_params =
      BuildTFVariantType(element_type);
  if (!variant_params.hasValue()) {
    return llvm::None;
  }

  BufferOffset<tflite::QuantizationParameters> q_params;
  if (auto qtype = element_type.dyn_cast<mlir::quant::UniformQuantizedType>()) {
    std::vector<float> scales = {static_cast<float>(qtype.getScale())};
    std::vector<int64_t> zero_points = {qtype.getZeroPoint()};
    q_params = tflite::CreateQuantizationParameters(
        // min and max values are not stored in the quantized type from MLIR, so
        // both are set to 0 in the flatbuffer when they are exported.
        builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>(scales),
        builder_.CreateVector<int64_t>(zero_points));
  } else if (auto qtype =
                 element_type
                     .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
    // Per-axis quantization: one scale/zero-point per slice along the
    // quantized dimension.
    std::vector<float> scales(qtype.getScales().begin(),
                              qtype.getScales().end());
    std::vector<int64_t> zero_points(qtype.getZeroPoints().begin(),
                                     qtype.getZeroPoints().end());
    q_params = tflite::CreateQuantizationParameters(
        builder_, /*min=*/0, /*max=*/0, builder_.CreateVector<float>(scales),
        builder_.CreateVector<int64_t>(zero_points),
        tflite::QuantizationDetails_NONE, /*details=*/0,
        qtype.getQuantizedDimension());
  } else if (quant_parameters.has_value()) {
    // Fall back to caller-supplied parameters (e.g. derived from a
    // quant stats op).
    q_params = quant_parameters.getValue();
  } else {
    q_params = tflite::CreateQuantizationParameters(builder_);
  }
  // Check if the value's uses includes an op and usage at an operand index
  // marked as a stateful. If so, set the tensor's is_variable as true
  // This is v1 ref variable semantics in the TFLite runtime.
  bool is_variable = false;
  for (auto& use : value.getUses()) {
    is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber());
    if (is_variable) {
      break;
    }
  }

  bool has_rank = type.hasRank();

  // Variable tensors point at buffer 0 (no constant data backing them).
  if (shape_signature.empty()) {
    return tflite::CreateTensor(
        builder_, builder_.CreateVector(shape), tflite_element_type,
        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
        /*is_variable=*/is_variable, s_params, /*shape_signature=*/0,
        /*has_rank=*/has_rank,
        variant_params->empty() ? 0 : builder_.CreateVector(*variant_params));
  } else {
    return tflite::CreateTensor(
        builder_, builder_.CreateVector(shape), tflite_element_type,
        (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params,
        /*is_variable=*/is_variable, s_params,
        /*shape_signature=*/builder_.CreateVector(shape_signature),
        /*has_rank=*/has_rank,
        variant_params->empty() ? 0 : builder_.CreateVector(*variant_params));
  }
}
997 | |
998 | BufferOffset<tflite::Operator> Translator::BuildIfOperator( |
999 | mlir::TF::IfOp op, const std::vector<int32_t>& operands, |
1000 | const std::vector<int32_t>& results) { |
1001 | auto opcode_index = GetOpcodeIndex("if" , tflite::BuiltinOperator_IF); |
1002 | int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str()); |
1003 | int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str()); |
1004 | auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index, |
1005 | else_subgraph_index) |
1006 | .Union(); |
1007 | auto inputs = builder_.CreateVector(operands); |
1008 | auto outputs = builder_.CreateVector(results); |
1009 | return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, |
1010 | tflite::BuiltinOptions_IfOptions, |
1011 | builtin_options); |
1012 | } |
1013 | |
1014 | BufferOffset<tflite::Operator> Translator::BuildCallOnceOperator( |
1015 | mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands, |
1016 | const std::vector<int32_t>& results) { |
1017 | auto opcode_index = |
1018 | GetOpcodeIndex("call_once" , tflite::BuiltinOperator_CALL_ONCE); |
1019 | int init_subgraph_index = |
1020 | subgraph_index_map_.at(op.session_init_function().str()); |
1021 | auto builtin_options = |
1022 | tflite::CreateCallOnceOptions(builder_, init_subgraph_index).Union(); |
1023 | auto inputs = builder_.CreateVector(operands); |
1024 | auto outputs = builder_.CreateVector(results); |
1025 | return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, |
1026 | tflite::BuiltinOptions_CallOnceOptions, |
1027 | builtin_options); |
1028 | } |
1029 | |
1030 | Optional<BufferOffset<tflite::Operator>> Translator::BuildWhileOperator( |
1031 | mlir::TFL::WhileOp op, const std::vector<int32_t>& operands, |
1032 | const std::vector<int32_t>& results) { |
1033 | auto opcode_index = GetOpcodeIndex("while" , tflite::BuiltinOperator_WHILE); |
1034 | auto get_call_index = [&](mlir::Block& b) -> Optional<int> { |
1035 | if (b.getOperations().size() != 2) return llvm::None; |
1036 | if (auto call_op = dyn_cast<mlir::func::CallOp>(b.front())) |
1037 | return subgraph_index_map_.at(call_op.getCallee().str()); |
1038 | return llvm::None; |
1039 | }; |
1040 | auto body_subgraph_index = get_call_index(op.body().front()); |
1041 | auto cond_subgraph_index = get_call_index(op.cond().front()); |
1042 | if (!body_subgraph_index || !cond_subgraph_index) |
1043 | return op.emitOpError("only single call cond/body while export supported" ), |
1044 | llvm::None; |
1045 | auto builtin_options = |
1046 | tflite::CreateWhileOptions(builder_, *cond_subgraph_index, |
1047 | *body_subgraph_index) |
1048 | .Union(); |
1049 | auto inputs = builder_.CreateVector(operands); |
1050 | auto outputs = builder_.CreateVector(results); |
1051 | return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, |
1052 | tflite::BuiltinOptions_WhileOptions, |
1053 | builtin_options); |
1054 | } |
1055 | |
1056 | BufferOffset<tflite::Operator> Translator::BuildNumericVerifyOperator( |
1057 | mlir::TFL::NumericVerifyOp op, const std::vector<int32_t>& operands, |
1058 | const std::vector<int32_t>& results) { |
1059 | float tolerance = op.tolerance().convertToFloat(); |
1060 | bool log_if_failed = op.log_if_failed(); |
1061 | auto fbb = std::make_unique<flexbuffers::Builder>(); |
1062 | fbb->Map([&]() { |
1063 | fbb->Float("tolerance" , tolerance); |
1064 | fbb->Bool("log_if_failed" , log_if_failed); |
1065 | }); |
1066 | fbb->Finish(); |
1067 | auto f = std::unique_ptr<flexbuffers::Builder>(fbb.release()); |
1068 | auto custom_option = f->GetBuffer(); |
1069 | auto opcode_index = |
1070 | GetOpcodeIndex("NumericVerify" , tflite::BuiltinOperator_CUSTOM); |
1071 | return tflite::CreateOperator( |
1072 | builder_, opcode_index, builder_.CreateVector(operands), |
1073 | builder_.CreateVector(results), tflite::BuiltinOptions_NONE, |
1074 | /*builtin_options=*/0, builder_.CreateVector<uint8_t>(custom_option), |
1075 | tflite::CustomOptionsFormat_FLEXBUFFERS); |
1076 | } |
1077 | |
1078 | BufferOffset<tflite::Operator> Translator::BuildCustomOperator( |
1079 | Operation* inst, mlir::TFL::CustomOp op, |
1080 | const std::vector<int32_t>& operands, const std::vector<int32_t>& results) { |
1081 | const std::string attrs = |
1082 | op.custom_option().cast<mlir::TFL::ConstBytesAttr>().getValue().str(); |
1083 | std::vector<uint8_t> custom_option_vector(attrs.size()); |
1084 | memcpy(custom_option_vector.data(), attrs.data(), attrs.size()); |
1085 | auto opcode_index = |
1086 | GetOpcodeIndex(op.custom_code().str(), tflite::BuiltinOperator_CUSTOM); |
1087 | return tflite::CreateOperator( |
1088 | builder_, opcode_index, builder_.CreateVector(operands), |
1089 | builder_.CreateVector(results), tflite::BuiltinOptions_NONE, |
1090 | /*builtin_options=*/0, |
1091 | builder_.CreateVector<uint8_t>(custom_option_vector), |
1092 | tflite::CustomOptionsFormat_FLEXBUFFERS); |
1093 | } |
1094 | |
1095 | Optional<CustomOptionsOffset> Translator::CreateFlexOpCustomOptions( |
1096 | const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { |
1097 | std::string node_def_str; |
1098 | if (!node_def.SerializeToString(&node_def_str)) { |
1099 | return emitError(loc, "failed to serialize tensorflow node_def" ), |
1100 | llvm::None; |
1101 | } |
1102 | |
1103 | auto flex_builder = std::make_unique<flexbuffers::Builder>(); |
1104 | flex_builder->Vector([&]() { |
1105 | flex_builder->String(node_def.op()); |
1106 | flex_builder->String(node_def_str); |
1107 | }); |
1108 | flex_builder->Finish(); |
1109 | return builder_.CreateVector(flex_builder->GetBuffer()); |
1110 | } |
1111 | |
1112 | Optional<CustomOptionsOffset> Translator::CreateCustomOpCustomOptions( |
1113 | const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { |
1114 | auto flex_builder = CreateFlexBuilderWithNodeAttrs(node_def, loc); |
1115 | return builder_.CreateVector(flex_builder->GetBuffer()); |
1116 | } |
1117 | |
1118 | std::unique_ptr<flexbuffers::Builder> |
1119 | Translator::CreateFlexBuilderWithNodeAttrs( |
1120 | const ::tensorflow::NodeDef& node_def, const mlir::Location& loc) { |
1121 | auto flex_builder = std::make_unique<flexbuffers::Builder>(); |
1122 | size_t map_start = flex_builder->StartMap(); |
1123 | using Item = std::pair<std::string, ::tensorflow::AttrValue>; |
1124 | std::vector<Item> attrs(node_def.attr().begin(), node_def.attr().end()); |
1125 | std::sort(attrs.begin(), attrs.end(), |
1126 | [](Item& p1, Item& p2) -> bool { return p1.first < p2.first; }); |
1127 | for (const Item& pair : attrs) { |
1128 | const char* key = pair.first.c_str(); |
1129 | const ::tensorflow::AttrValue& attr = pair.second; |
1130 | switch (attr.value_case()) { |
1131 | case ::tensorflow::AttrValue::kS: |
1132 | flex_builder->String(key, attr.s()); |
1133 | break; |
1134 | case ::tensorflow::AttrValue::kType: { |
1135 | auto status_or_tfl_type = tflite::TfTypeToTflType(attr.type()); |
1136 | if (status_or_tfl_type.ok()) { |
1137 | flex_builder->Int(key, status_or_tfl_type.value()); |
1138 | } else { |
1139 | emitWarning(loc, "ignoring unsupported tensorflow type: " ) |
1140 | << std::to_string(attr.type()); |
1141 | } |
1142 | break; |
1143 | } |
1144 | case ::tensorflow::AttrValue::kI: |
1145 | flex_builder->Int(key, attr.i()); |
1146 | break; |
1147 | case ::tensorflow::AttrValue::kF: |
1148 | flex_builder->Float(key, attr.f()); |
1149 | break; |
1150 | case ::tensorflow::AttrValue::kB: |
1151 | flex_builder->Bool(key, attr.b()); |
1152 | break; |
1153 | case tensorflow::AttrValue::kList: |
1154 | if (attr.list().s_size() > 0) { |
1155 | auto start = flex_builder->StartVector(key); |
1156 | for (const std::string& v : attr.list().s()) { |
1157 | flex_builder->Add(v); |
1158 | } |
1159 | flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); |
1160 | } else if (attr.list().i_size() > 0) { |
1161 | auto start = flex_builder->StartVector(key); |
1162 | for (const int64_t v : attr.list().i()) { |
1163 | flex_builder->Add(v); |
1164 | } |
1165 | flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); |
1166 | } else if (attr.list().f_size() > 0) { |
1167 | auto start = flex_builder->StartVector(key); |
1168 | for (const float v : attr.list().f()) { |
1169 | flex_builder->Add(v); |
1170 | } |
1171 | flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false); |
1172 | } else { |
1173 | emitWarning(loc, |
1174 | "ignoring unsupported type in list attribute with key: " ) |
1175 | << key; |
1176 | } |
1177 | break; |
1178 | default: |
1179 | emitWarning(loc, "ignoring unsupported attribute type with key: " ) |
1180 | << key; |
1181 | break; |
1182 | } |
1183 | } |
1184 | flex_builder->EndMap(map_start); |
1185 | flex_builder->Finish(); |
1186 | return flex_builder; |
1187 | } |
1188 | |
1189 | uint32_t Translator::GetOpcodeIndex(const std::string& op_name, |
1190 | tflite::BuiltinOperator builtin) { |
1191 | auto it = opcode_index_map_.insert({op_name, 0}); |
1192 | |
1193 | // If the insert succeeded, the opcode has not been created already. Create a |
1194 | // new operator code and update its index value in the map. |
1195 | if (it.second) { |
1196 | it.first->second = opcodes_.size(); |
1197 | auto custom_code = builtin == tflite::BuiltinOperator_CUSTOM |
1198 | ? builder_.CreateString(op_name) |
1199 | : BufferOffset<flatbuffers::String>(); |
1200 | // Use version 0 for builtin op. This is a way to serialize version field to |
1201 | // flatbuffer (since 0 is non default) and it will be corrected later. |
1202 | int32_t op_version = builtin != tflite::BuiltinOperator_CUSTOM ? 0 : 1; |
1203 | opcodes_.push_back(CreateOperatorCode(builder_, /*builtin_code=*/builtin, |
1204 | custom_code, op_version)); |
1205 | } |
1206 | return it.first->second; |
1207 | } |
1208 | |
Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
    Operation* inst, std::vector<int32_t> operands,
    const std::vector<int32_t>& results,
    const std::vector<int32_t>& intermediates) {
  // Serializes one op as a flatbuffer Operator, dispatching on its dialect:
  // TFL ops become builtin (or TFL-custom) operators; TF ops become flex or
  // custom operators. Returns llvm::None (after emitting an error) when the
  // op cannot be exported.
  const auto* dialect = inst->getDialect();
  if (!dialect) {
    inst->emitOpError("dialect is not registered" );
    return llvm::None;
  }

  // If TFLite built in op, create operator as a builtin op.
  if (dialect == tfl_dialect_) {
    // Only if built-in TFLite op emission is enabled, would legalization have
    // converted any TF->TFL.
    if (!enabled_op_types_.contains(OpType::kTfliteBuiltin)) {
      return inst->emitOpError(
                 "is a TFLite builtin op but builtin emission is not enabled" ),
             llvm::None;
    }

    auto builtin_code = GetBuiltinOpCode(inst);
    if (!builtin_code) {
      // No builtin code: the op may still be one of the special TFL ops with
      // dedicated serialization (NumericVerify, Custom, While).
      if (auto verify_op = dyn_cast<mlir::TFL::NumericVerifyOp>(inst)) {
        return BuildNumericVerifyOperator(verify_op, operands, results);
      }
      if (auto custom_op = dyn_cast<mlir::TFL::CustomOp>(inst)) {
        return BuildCustomOperator(inst, custom_op, operands, results);
      }
      if (auto whileOp = dyn_cast<mlir::TFL::WhileOp>(inst)) {
        if (inst->getNumOperands() != inst->getNumResults()) {
          inst->emitOpError(
              "number of operands and results don't match, only canonical "
              "TFL While supported" );
          return llvm::None;
        }
        return BuildWhileOperator(whileOp, operands, results);
      }

      inst->emitOpError("is not a supported TFLite op" );
      return llvm::None;
    }

    if (*builtin_code == tflite::BuiltinOperator_CALL_ONCE) {
      if (auto initOp = dyn_cast<mlir::TFL::CallOnceOp>(inst)) {
        return BuildCallOnceOperator(initOp, operands, results);
      }
    }

    std::string op_name = inst->getName().getStringRef().str();
    uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code);

    // If this is TransposeConv we need to do a special case of ignoring the
    // optional tensor, to allow newly created models to run on old runtimes.
    if (*builtin_code == tflite::BuiltinOperator_TRANSPOSE_CONV) {
      if (operands.size() == 4 && operands.at(3) == -1) {
        operands.pop_back();
      }
    }

    auto offset = CreateFlatBufferOperator(inst, opcode_index, operands,
                                           results, intermediates, &builder_);
    if (!offset) {
      inst->emitOpError("is not a supported TFLite op" );
    }
    return offset;
  }

  if (dialect == tf_dialect_) {
    // TF::IfOp has a dedicated builtin (IF) serialization.
    if (auto ifOp = dyn_cast<mlir::TF::IfOp>(inst)) {
      return BuildIfOperator(ifOp, operands, results);
    }

    CustomOptionsOffset custom_options;

    // Ops in TF dialect can either be custom ops or flex ops.
    // The reason we go directly from TensorFlow dialect MLIR to tensorflow
    // node instead of going to TF table gen'd ops via generated code is that
    // we do not want to restrict custom and flex op conversion support to
    // only those TF ops that are currently registered in MLIR. The current
    // model is of an open op system.
    //
    // The following algorithm is followed:
    //   if flex is enabled and the op is allowlisted as flex
    //     we emit op as flex.
    //   if custom is enabled
    //     we emit the op as custom.
    auto node_def = GetTensorFlowNodeDef(inst);
    if (!node_def) {
      return llvm::None;
    }

    std::string op_name = node_def->op();
    std::string op_desc = GetOpDescriptionForDebug(inst);

    // Resource ops are recorded so a warning can be reported later.
    if (IsTFResourceOp(inst)) {
      resource_ops_[op_name].insert(op_desc);
    }

    // Flex eligibility: not explicitly unsupported, and either allowlisted or
    // user-selected/allow-all with a registered TF kernel.
    const bool is_allowed_flex_op =
        !IsUnsupportedFlexOp(node_def->op()) &&
        (IsAllowlistedFlexOp(node_def->op()) ||
         (((select_user_tf_ops_.count(node_def->op()) != 0) ||
           allow_all_select_tf_ops_) &&
          (tensorflow::OpRegistry::Global()->LookUp(node_def->op()) !=
           nullptr)));

    // Flex op case
    // Eventually, the allowlist will go away and we will rely on some TF op
    // trait (e.g. No side effect) to determine if it is a supported "Flex"
    // op or not.
    if (is_allowed_flex_op && enabled_op_types_.contains(OpType::kSelectTf)) {
      // Construct ops as flex op encoding TensorFlow node definition
      // as custom options.
      // Flex ops are named with the kFlexOpNamePrefix prefix to the actual
      // TF op name.
      op_name = std::string(kFlexOpNamePrefix) + node_def->op();
      if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) {
        custom_options = *options;
      } else {
        return llvm::None;
      }

      // Gather flex ops.
      flex_ops_[op_name].insert(op_desc);
    } else if (enabled_op_types_.contains(OpType::kCustomOp)) {
      // Generic case of custom ops - write using flex buffers since that
      // is the only custom options supported by TFLite today.
      op_name = node_def->op();
      if (auto options =
              CreateCustomOpCustomOptions(*node_def, inst->getLoc())) {
        custom_options = *options;
      } else {
        return llvm::None;
      }

      // Gather custom ops.
      custom_ops_[op_name].insert(op_desc);
    } else {
      // Insert failed op to `flex_ops` or `custom_ops`.
      if (is_allowed_flex_op) {
        failed_flex_ops_[op_name].insert(op_desc);
        tfl::AttachErrorCode(
            inst->emitOpError("is neither a custom op nor a flex op" ),
            tflite::metrics::ConverterErrorData::ERROR_NEEDS_FLEX_OPS);
      } else {
        failed_custom_ops_[op_name].insert(op_desc);
        tfl::AttachErrorCode(
            inst->emitOpError("is neither a custom op nor a flex op" ),
            tflite::metrics::ConverterErrorData::ERROR_NEEDS_CUSTOM_OPS);
      }
      return llvm::None;
    }

    // Both flex and custom ops are serialized as CUSTOM operators carrying
    // the options built above.
    uint32_t opcode_index =
        GetOpcodeIndex(op_name, tflite::BuiltinOperator_CUSTOM);
    auto inputs = builder_.CreateVector(operands);
    auto outputs = builder_.CreateVector(results);

    return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
                                  tflite::BuiltinOptions_NONE,
                                  /*builtin_options=*/0,
                                  /*custom_options=*/custom_options,
                                  tflite::CustomOptionsFormat_FLEXBUFFERS,
                                  /*mutating_variable_inputs=*/0);
  }

  return inst->emitOpError(
             "is not any of a builtin TFLite op, a flex TensorFlow op or a "
             "custom TensorFlow op" ),
         llvm::None;
}
1380 | |
1381 | void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { |
1382 | auto dict_attr = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function" ); |
1383 | if (!dict_attr) return; |
1384 | |
1385 | llvm::SmallVector<llvm::StringRef, 2> input_names; |
1386 | llvm::SmallVector<llvm::StringRef, 2> output_names; |
1387 | if (auto str = dict_attr.get("inputs" ).dyn_cast_or_null<mlir::StringAttr>()) { |
1388 | str.getValue().split(input_names, ',', /*MaxSplit=*/-1, |
1389 | /*KeepEmpty=*/false); |
1390 | if (input_names.size() != fn.getNumArguments()) { |
1391 | fn.emitWarning() << "invalid entry function specification" ; |
1392 | return; |
1393 | } |
1394 | for (const auto& it : llvm::enumerate(fn.getArguments())) { |
1395 | name_mapper_.InitOpName(it.value(), input_names[it.index()].trim()); |
1396 | } |
1397 | *has_input_attr = true; |
1398 | } |
1399 | |
1400 | if (auto str = |
1401 | dict_attr.get("outputs" ).dyn_cast_or_null<mlir::StringAttr>()) { |
1402 | str.getValue().split(output_names, ',', /*MaxSplit=*/-1, |
1403 | /*KeepEmpty=*/false); |
1404 | auto term = fn.back().getTerminator(); |
1405 | if (output_names.size() != term->getNumOperands()) { |
1406 | fn.emitWarning() << "output names (" << output_names.size() |
1407 | << ") != terminator operands (" << term->getNumOperands() |
1408 | << ")" ; |
1409 | return; |
1410 | } |
1411 | for (const auto& it : llvm::enumerate(term->getOperands())) { |
1412 | name_mapper_.InitOpName(it.value(), output_names[it.index()].trim()); |
1413 | } |
1414 | } |
1415 | } |
1416 | |
1417 | bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) { |
1418 | std::vector<int> operand_indices; |
1419 | if (!mlir::TFL::IsStatefulOp(op, &operand_indices)) return false; |
1420 | return absl::c_find(operand_indices, operand_index) != operand_indices.end(); |
1421 | } |
1422 | |
1423 | BufferOffset<tflite::QuantizationParameters> |
1424 | Translator::GetQuantizationForQuantStatsOpOutput( |
1425 | mlir::quantfork::StatisticsOp stats_op) { |
1426 | auto layer_stats = stats_op.getLayerStats().cast<mlir::DenseFPElementsAttr>(); |
1427 | Optional<mlir::ElementsAttr> axis_stats = stats_op.getAxisStats(); |
1428 | Optional<uint64_t> axis = stats_op.getAxis(); |
1429 | std::vector<float> mins, maxs; |
1430 | mlir::DenseFPElementsAttr min_max_attr = |
1431 | axis_stats.has_value() |
1432 | ? axis_stats.getValue().cast<mlir::DenseFPElementsAttr>() |
1433 | : layer_stats; |
1434 | |
1435 | for (const auto& index_and_value : |
1436 | llvm::enumerate(min_max_attr.getValues<llvm::APFloat>())) { |
1437 | const llvm::APFloat value = index_and_value.value(); |
1438 | if (index_and_value.index() % 2 == 0) { |
1439 | mins.push_back(value.convertToFloat()); |
1440 | } else { |
1441 | maxs.push_back(value.convertToFloat()); |
1442 | } |
1443 | } |
1444 | |
1445 | return tflite::CreateQuantizationParameters( |
1446 | builder_, builder_.CreateVector<float>(mins), |
1447 | builder_.CreateVector<float>(maxs), /*scale=*/0, /*zero_point=*/0, |
1448 | tflite::QuantizationDetails_NONE, /*details=*/0, |
1449 | /*quantized_dimension=*/axis.has_value() ? axis.getValue() : 0); |
1450 | } |
1451 | |
// Exports `region` (the body of the function named `name`) as the TFLite
// SubGraph at position `index`. Builds tensors/buffers for block arguments
// and op results, converts each operation to a flatbuffer Operator, and
// records control-edge metadata. Returns llvm::None on failure.
Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
    const std::string& name, Region* region, const int index) {
  // Strip control-edge wrappers first; the resulting (from, to) operation
  // index pairs are re-encoded as operator indices at the end.
  const auto control_edges = ExtractControlEdges(&region->front());
  bool has_input_attr = false;
  if (auto fn = dyn_cast<FuncOp>(region->getParentOp())) {
    InitializeNamesFromAttribute(fn, &has_input_attr);
  }
  std::vector<BufferOffset<tflite::Tensor>> tensors;
  // Maps each MLIR Value to its position in `tensors`.
  llvm::DenseMap<Value, int> tensor_index_map;

  // Builds tensor and buffer for argument or operation result. Returns false
  // on failure.
  auto build_tensor_and_buffer = [&](Value value, const int subgraph_index,
                                     const std::string& tensor_name) {
    // NoneType represents optional and may be skipped here.
    if (value.getType().isa<NoneType>()) {
      return true;
    }

    tensor_index_map.insert({value, tensors.size()});
    tensor_index_map_[subgraph_index][tensor_name] = tensors.size();
    // If the value's single user is a "quant.stats" op, attach the
    // quantization statistics it carries to the tensor being built.
    Optional<BufferOffset<tflite::QuantizationParameters>> quant_parameters;
    if (value.hasOneUse()) {
      auto stats_op =
          llvm::dyn_cast<mlir::quantfork::StatisticsOp>(*value.user_begin());
      if (stats_op) {
        quant_parameters = GetQuantizationForQuantStatsOpOutput(stats_op);
      }
    }
    auto tensor_or =
        BuildTensor(value, tensor_name, buffers_.size(), quant_parameters);
    if (!tensor_or) return false;
    tensors.push_back(*tensor_or);

    // TODO(ashwinm): Check if for stateful tensors, if it is also needed to
    // make the Buffer empty apart from setting the buffer_idx=0 in the
    // Tensor. This does not seem to affect runtime behavior for RNN/LSTM,
    // but would be good for reducing memory footprint.
    // Values defined by an op may carry constant data; block arguments never
    // do, so they share the canonical empty buffer.
    if (auto* inst = value.getDefiningOp()) {
      auto buffer_or = BuildBuffer(inst);
      if (!buffer_or) return false;
      buffers_.push_back(*buffer_or);
    } else {
      buffers_.push_back(empty_buffer_);
    }
    return true;
  };

  std::vector<BufferOffset<tflite::Operator>> operators;

  // Maps positions of operations in bb to positions in operators
  llvm::DenseMap<int, int> operation_index_to_operator_index;
  std::vector<Operation*> operators_in_mlir;
  auto& bb = region->front();

  // Main function's arguments are first passed to `input` op so they don't
  // have associated tensor and buffer. Build FlatBuffer tensor and buffer for
  // other functions.
  for (unsigned i = 0, e = bb.getNumArguments(); i < e; ++i) {
    mlir::BlockArgument arg = bb.getArgument(i);
    std::string tensor_name;
    if (has_input_attr)
      tensor_name = std::string(name_mapper_.GetUniqueName(arg));
    if (tensor_name.empty()) tensor_name = absl::StrCat("arg" , i);
    if (!build_tensor_and_buffer(arg, index, tensor_name)) return llvm::None;
  }

  // Convert every operation; on failure keep iterating so all unsupported
  // ops in the region are reported before giving up.
  bool failed_once = false;
  for (auto& item : llvm::enumerate(bb)) {
    Operation& inst = item.value();
    const int operation_index = item.index();
    if (inst.hasTrait<mlir::OpTrait::IsTerminator>()) break;
    // For "quant.stats" op, it's used to store the quantization parameters info
    // and its output should be then replaced by its input value.
    if (auto quant_stats_op =
            llvm::dyn_cast<mlir::quantfork::StatisticsOp>(inst)) {
      continue;
    }
    std::vector<int32_t> intermediates;
    // Build intermediate tensors for tfl.lstm and insert these tensors into
    // flatbuffer.
    if (llvm::isa<mlir::TFL::LSTMOp, mlir::TFL::UnidirectionalSequenceLSTMOp>(
            inst)) {
      std::vector<std::string> intermediate_names = {
          "input_to_input_intermediate" , "input_to_forget_intermediate" ,
          "input_to_cell_intermediate" , "input_to_output_intermediate" ,
          "effective_hidden_scale_intermediate" };
      for (const std::string& intermediate : intermediate_names) {
        auto intermediate_attr = inst.getAttr(intermediate);
        if (auto attr = intermediate_attr.dyn_cast_or_null<mlir::TypeAttr>()) {
          Type qtype = attr.getValue();
          auto tensor_or = BuildTensorFromType(
              qtype, name_mapper_.GetUniqueName(intermediate).str());
          if (!tensor_or.has_value()) {
            continue;
          } else {
            intermediates.push_back(tensors.size());
            tensors.push_back(tensor_or.getValue());
          }
        }
      }
    }

    for (auto val : inst.getResults()) {
      std::string tensor_name = UniqueName(val);
      // For "tfl.numeric_verify" op, the name is used to find out the original
      // activation tensor rather than its own unique name in the visualization
      // or debugging tools.
      auto builtin_code = GetBuiltinOpCode(&inst);
      if (!builtin_code && dyn_cast<mlir::TFL::NumericVerifyOp>(&inst)) {
        // The first operand is the quantized activation, the target of this
        // NumericVerify op.
        auto quantized_op_val = inst.getOperands().front();
        tensor_name = "NumericVerify/" + UniqueName(quantized_op_val) + ":" +
                      std::to_string(tensor_index_map[quantized_op_val]);
      }
      if (!build_tensor_and_buffer(val, index, tensor_name)) return llvm::None;
    }

    // Skip constant ops as they don't represent a TFLite operator.
    if (IsConst(&inst)) continue;

    // Fetch operand and result tensor indices.
    std::vector<int32_t> results;
    results.reserve(inst.getNumResults());
    for (auto result : inst.getResults()) {
      results.push_back(tensor_index_map.lookup(result));
    }
    Operation* real_inst = &inst;
    std::vector<int32_t> operands;
    operands.reserve(real_inst->getNumOperands());
    for (auto operand : real_inst->getOperands()) {
      if (operand.getType().isa<NoneType>())
        operands.push_back(kTfLiteOptionalTensor);
      else if (auto stats_op =
                   llvm::dyn_cast_or_null<mlir::quantfork::StatisticsOp>(
                       operand.getDefiningOp()))
        // "quant.stats" ops are transparent: use their input tensor instead.
        operands.push_back(tensor_index_map.lookup(stats_op.getArg()));
      else
        operands.push_back(tensor_index_map.lookup(operand));
    }

    // CustomTfOp is just a wrapper around a TF op, we export the custom Op
    // not the wrapper, so we fetch the op from the region.
    if (auto custom_op = dyn_cast<mlir::TFL::CustomTfOp>(inst)) {
      // If we have custom op with a region, then use the first op in the
      // region, if it exists, otherwise just use params for custom op.
      if (!custom_op.body().empty()) {
        real_inst = &custom_op.body().front().front();
      } else {
        module_.emitError(
            "Invalid CustomTfOp: Custom TF Op have empty region." );
      }
    }
    if (auto tfl_operator =
            BuildOperator(real_inst, operands, results, intermediates)) {
      operation_index_to_operator_index.try_emplace(operation_index,
                                                    operators.size());
      operators.push_back(*tfl_operator);
      operators_in_mlir.push_back(real_inst);
    } else {
      failed_once = true;
    }
  }
  // Record the MLIR operation behind each exported operator; used later to
  // attach diagnostics (e.g. GPU-compatibility errors) back to source ops.
  if (index + 1 > subgraph_op_inst_map_.size()) {
    subgraph_op_inst_map_.resize(index + 1);
  }
  subgraph_op_inst_map_[index] = operators_in_mlir;
  if (failed_once) return llvm::None;

  // Get input and output tensor indices for the subgraph.
  std::vector<int32_t> inputs, outputs;
  for (auto arg : bb.getArguments()) {
    inputs.push_back(tensor_index_map[arg]);
  }
  for (auto result : bb.getTerminator()->getOperands()) {
    outputs.push_back(tensor_index_map[result]);
  }
  // Re-encode control edges from operation indices to operator indices.
  // NOTE(review): a dangling edge emits an error but is still recorded below
  // (operator_index defaults to 0 via operator[]) — confirm this is intended.
  for (const auto& [from, to] : control_edges) {
    for (int what : {from, to}) {
      if (operation_index_to_operator_index.count(what) == 0) {
        module_.emitError(
            "dangling control edge -- at least one vertex Operation isn't a "
            "flatbuffer Operator." );
      }
    }
    model_control_dependencies_[index].emplace_back(
        operation_index_to_operator_index[from],
        operation_index_to_operator_index[to]);
  }
  return tflite::CreateSubGraph(
      builder_, builder_.CreateVector(tensors), builder_.CreateVector(inputs),
      builder_.CreateVector(outputs), builder_.CreateVector(operators),
      /*name=*/builder_.CreateString(name));
}
1647 | |
1648 | BufferOffset<tflite::Metadata> Translator::BuildMetadata(StringRef name, |
1649 | StringRef content) { |
1650 | auto buffer_index = buffers_.size(); |
1651 | auto buffer_data = builder_.CreateVector( |
1652 | reinterpret_cast<const uint8_t*>(content.data()), content.size()); |
1653 | buffers_.push_back(tflite::CreateBuffer(builder_, buffer_data)); |
1654 | return tflite::CreateMetadataDirect(builder_, name.data(), buffer_index); |
1655 | } |
1656 | |
// Serializes all model metadata entries: the module's "tfl.metadata"
// dictionary, a placeholder "min_runtime_version" entry (patched after op
// versions are finalized), converter-provided metadata_, and — when any
// exist — the model control dependencies. Returns llvm::None when
// tfl.metadata holds a non-string value.
Optional<VectorBufferOffset<BufferOffset<tflite::Metadata>>>
Translator::CreateMetadataVector() {
  auto dict_attr = module_->getAttrOfType<mlir::DictionaryAttr>("tfl.metadata" );
  std::vector<BufferOffset<tflite::Metadata>> metadata;
  if (dict_attr) {
    for (const auto& named_attr : dict_attr) {
      StringRef name = named_attr.getName();
      mlir::Attribute attr = named_attr.getValue();
      if (auto content = attr.dyn_cast<StringAttr>()) {
        metadata.push_back(BuildMetadata(name, content.getValue()));
      } else {
        module_.emitError(
            "all values in tfl.metadata's dictionary key-value pairs should be "
            "string attributes" );
        return llvm::None;
      }
    }
  }
  // Runtime version string is generated after we update the op
  // versions. Here we put a 16-byte dummy string as a placeholder. We choose
  // 16-byte because it's the alignment of buffers in flatbuffer, so it won't
  // cause any waste of space if the actual string is shorter than 16 bytes.
  constexpr std::size_t kByteStringSize = 16;
  metadata.push_back(
      BuildMetadata("min_runtime_version" , std::string(kByteStringSize, '\0')));
  // Converter-supplied metadata values are truncated or zero-padded to the
  // same fixed 16-byte width.
  for (const auto& kv : metadata_) {
    const std::string& val = kv.second;
    // Only take the first kByteStringSize values.
    const int count = std::min(kByteStringSize, val.length());
    std::string value = std::string(kByteStringSize, '\0')
                            .assign(val.begin(), val.begin() + count);
    metadata.push_back(BuildMetadata(kv.first, value));
  }

  // Populate the model control dependencies metadata entry.
  if (std::any_of(
          model_control_dependencies_.begin(),
          model_control_dependencies_.end(),
          [](const tflite::ControlEdges& edges) { return !edges.empty(); })) {
    metadata.push_back(
        BuildMetadata(tflite::kModelControlDependenciesMetadataKey,
                      tflite::SerializeModelControlDependencies(
                          model_control_dependencies_)));
  }
  return builder_.CreateVector(metadata);
}
1703 | |
1704 | // Helper method that returns list of all strings in a StringAttr identified |
1705 | // by 'attr_key' and values are separated by a comma. |
1706 | llvm::SmallVector<llvm::StringRef, 2> GetStringsFromAttrWithSeparator( |
1707 | mlir::DictionaryAttr attr, const std::string& attr_key) { |
1708 | llvm::SmallVector<llvm::StringRef, 2> result; |
1709 | if (auto str = attr.get(attr_key).dyn_cast_or_null<mlir::StringAttr>()) { |
1710 | str.getValue().split(result, ',', /*MaxSplit=*/-1, |
1711 | /*KeepEmpty=*/false); |
1712 | } |
1713 | return result; |
1714 | } |
1715 | |
1716 | // Helper method that return list of string for all the StringAttr in the |
1717 | // Attribute identified by 'attr_name'. |
1718 | std::vector<std::string> GetStringsFromDictionaryAttr( |
1719 | const llvm::SmallVector<mlir::DictionaryAttr, 4>& dict_attrs, |
1720 | const std::string& attr_name) { |
1721 | std::vector<std::string> result; |
1722 | for (const auto& arg_attr : dict_attrs) { |
1723 | if (!arg_attr) continue; |
1724 | |
1725 | auto attrs = arg_attr.getValue(); |
1726 | for (const auto attr : attrs) { |
1727 | if (attr.getName().str() == attr_name) { |
1728 | auto array_attr = attr.getValue().dyn_cast_or_null<mlir::ArrayAttr>(); |
1729 | if (!array_attr || array_attr.empty()) continue; |
1730 | auto string_attr = array_attr[0].dyn_cast_or_null<mlir::StringAttr>(); |
1731 | if (!string_attr) continue; |
1732 | result.push_back(string_attr.getValue().str()); |
1733 | } |
1734 | } |
1735 | } |
1736 | return result; |
1737 | } |
1738 | |
1739 | std::vector<SignatureDefData> BuildSignaturedef( |
1740 | FuncOp main_op, const std::string& saved_model_tag, |
1741 | const uint32_t subgraph_index, tensorflow::OpOrArgNameMapper& name_mapper) { |
1742 | static const char kSignatureDefIndexPath[] = "tf_saved_model.index_path" ; |
1743 | static const char kEntryFunctionAttributes[] = "tf.entry_function" ; |
1744 | |
1745 | // Fetch inputs and outputs from the signature. |
1746 | llvm::SmallVector<mlir::DictionaryAttr, 4> arg_attrs, res_attrs; |
1747 | main_op.getAllArgAttrs(arg_attrs); |
1748 | main_op.getAllResultAttrs(res_attrs); |
1749 | std::vector<std::string> sig_def_inputs = |
1750 | GetStringsFromDictionaryAttr(arg_attrs, kSignatureDefIndexPath); |
1751 | std::vector<std::string> sig_def_outputs = |
1752 | GetStringsFromDictionaryAttr(res_attrs, kSignatureDefIndexPath); |
1753 | |
1754 | // If no defined saved model signature, then return empty list. |
1755 | // This can happen when we are converting model not from SavedModel. |
1756 | if (sig_def_inputs.empty() && sig_def_outputs.empty()) return {}; |
1757 | |
1758 | // Fetch function inputs and outputs tensor names. |
1759 | auto dict_attr = |
1760 | main_op->getAttrOfType<mlir::DictionaryAttr>(kEntryFunctionAttributes); |
1761 | if (!dict_attr) return {}; |
1762 | |
1763 | // Get Input and output tensor names from attribute. |
1764 | llvm::SmallVector<llvm::StringRef, 2> input_names = |
1765 | GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"inputs" ); |
1766 | llvm::SmallVector<llvm::StringRef, 2> output_names = |
1767 | GetStringsFromAttrWithSeparator(dict_attr, /*attr_key=*/"outputs" ); |
1768 | |
1769 | // Verify input size match the number of arguments. |
1770 | if (input_names.size() != main_op.getNumArguments()) { |
1771 | main_op.emitWarning() << "invalid entry function specification" ; |
1772 | return {}; |
1773 | } |
1774 | // Verify output size match the number of arguments. |
1775 | auto term = main_op.back().getTerminator(); |
1776 | if (output_names.size() != term->getNumOperands()) { |
1777 | main_op.emitWarning() << "output names (" << output_names.size() |
1778 | << ") != terminator operands (" |
1779 | << term->getNumOperands() << ")" ; |
1780 | return {}; |
1781 | } |
1782 | // Verify number of tensors for inputs and outputs matches size |
1783 | // of the list in the signature def. |
1784 | if (input_names.size() != sig_def_inputs.size() || |
1785 | output_names.size() != sig_def_outputs.size()) { |
1786 | main_op.emitWarning( |
1787 | "Mismatch between signature def inputs/outputs and main function " |
1788 | "arguments." ); |
1789 | return {}; |
1790 | } |
1791 | // Exported method name. |
1792 | auto exported_name = |
1793 | main_op->getAttrOfType<mlir::ArrayAttr>("tf_saved_model.exported_names" ); |
1794 | if (exported_name.empty()) { |
1795 | main_op.emitError("Empty exported names for main Function" ); |
1796 | return {}; |
1797 | } |
1798 | // Fill the SignatureDefData container. |
1799 | // We create vector of size 1 as TFLite now supports only 1 signatureDef. |
1800 | std::vector<SignatureDefData> result(1); |
1801 | for (int i = 0; i < input_names.size(); ++i) { |
1802 | result[0].inputs[sig_def_inputs[i]] = input_names[i].str(); |
1803 | } |
1804 | for (int i = 0; i < output_names.size(); ++i) { |
1805 | // Fetch the name from the actual operand and not rely on names from |
1806 | // outputs as deduping can make them invalid after conversion. |
1807 | auto& operand = term->getOpOperand(i); |
1808 | auto unique_name = std::string(name_mapper.GetUniqueName(operand.get())); |
1809 | result[0].outputs[sig_def_outputs[i]] = unique_name; |
1810 | } |
1811 | if (auto name_attr = exported_name[0].dyn_cast_or_null<StringAttr>()) |
1812 | result[0].signature_key = name_attr.getValue().str(); |
1813 | result[0].subgraph_index = subgraph_index; |
1814 | return result; |
1815 | } |
1816 | |
1817 | std::vector<BufferOffset<tflite::TensorMap>> Translator::GetList( |
1818 | const int subgraph_index, const std::map<std::string, std::string>& items) { |
1819 | std::vector<BufferOffset<tflite::TensorMap>> result; |
1820 | for (const auto& item : items) { |
1821 | auto name_buf = builder_.CreateString(item.first); |
1822 | tflite::TensorMapBuilder tensor_map_builder(builder_); |
1823 | tensor_map_builder.add_name(name_buf); |
1824 | tensor_map_builder.add_tensor_index( |
1825 | tensor_index_map_[subgraph_index][item.second]); |
1826 | result.push_back(tensor_map_builder.Finish()); |
1827 | } |
1828 | return result; |
1829 | } |
1830 | |
1831 | Optional<VectorBufferOffset<BufferOffset<tflite::SignatureDef>>> |
1832 | Translator::CreateSignatureDefs( |
1833 | const std::vector<SignatureDefData>& signature_defs) { |
1834 | std::vector<BufferOffset<tflite::SignatureDef>> signature_defs_buffer; |
1835 | // When we export each function in the module op, intentionally, we export the |
1836 | // entry functions at the beginning of the subgraph list and the |
1837 | // subgraph_index is the index in entry functions and at the same, is the |
1838 | // index in the subgraph list. |
1839 | int subgraph_index = 0; |
1840 | for (const auto& signature_def_data : signature_defs) { |
1841 | auto inputs = GetList(subgraph_index, signature_def_data.inputs); |
1842 | auto outputs = GetList(subgraph_index, signature_def_data.outputs); |
1843 | auto inputs_buf = builder_.CreateVector(inputs); |
1844 | auto outputs_buf = builder_.CreateVector(outputs); |
1845 | auto signature_key_buf = |
1846 | builder_.CreateString(signature_def_data.signature_key); |
1847 | tflite::SignatureDefBuilder sig_def_builder(builder_); |
1848 | sig_def_builder.add_inputs(inputs_buf); |
1849 | sig_def_builder.add_outputs(outputs_buf); |
1850 | sig_def_builder.add_signature_key(signature_key_buf); |
1851 | sig_def_builder.add_subgraph_index(signature_def_data.subgraph_index); |
1852 | signature_defs_buffer.push_back(sig_def_builder.Finish()); |
1853 | ++subgraph_index; |
1854 | } |
1855 | |
1856 | return builder_.CreateVector(signature_defs_buffer); |
1857 | } |
1858 | |
1859 | bool UpdateEntryFunction(ModuleOp module) { |
1860 | if (module.lookupSymbol<FuncOp>("main" ) != nullptr) { |
1861 | // We already have an entry function. |
1862 | return true; |
1863 | } |
1864 | |
1865 | int entry_func_count = 0; |
1866 | FuncOp entry_func = nullptr; |
1867 | for (auto fn : module.getOps<FuncOp>()) { |
1868 | auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function" ); |
1869 | if (!attrs || attrs.empty()) continue; |
1870 | ++entry_func_count; |
1871 | entry_func = fn; |
1872 | } |
1873 | |
1874 | // We should have at least one entry function. |
1875 | if (entry_func_count == 0) return false; |
1876 | |
1877 | if (entry_func_count == 1) { |
1878 | // Update the entry func to main when the entry func is only & one. |
1879 | entry_func.setName(StringAttr::get(module.getContext(), "main" )); |
1880 | } |
1881 | return true; |
1882 | } |
1883 | |
1884 | Optional<std::string> Translator::Translate( |
1885 | ModuleOp module, const toco::TocoFlags& toco_flags, |
1886 | const std::unordered_set<std::string>& tags, |
1887 | OpOrArgNameMapper* op_or_arg_name_mapper, |
1888 | const std::map<std::string, std::string>& metadata) { |
1889 | OpOrArgLocNameMapper default_op_or_arg_name_mapper; |
1890 | if (!op_or_arg_name_mapper) |
1891 | op_or_arg_name_mapper = &default_op_or_arg_name_mapper; |
1892 | if (!UpdateEntryFunction(module)) return llvm::None; |
1893 | if (!IsValidTFLiteMlirModule(module)) return llvm::None; |
1894 | Translator translator(module, toco_flags, tags, op_or_arg_name_mapper, |
1895 | metadata); |
1896 | return translator.TranslateInternal(); |
1897 | } |
1898 | |
1899 | bool Translator::CheckGpuDelegateCompatibility(uint8_t* model_buffer_pointer) { |
1900 | bool gpu_compatibile = true; |
1901 | auto model = tflite::GetModel(model_buffer_pointer); |
1902 | auto subgraphs = model->subgraphs(); |
1903 | |
1904 | for (int i = 0; i < subgraphs->Length(); ++i) { |
1905 | const tflite::SubGraph* subgraph = subgraphs->Get(i); |
1906 | for (int j = 0; j < subgraph->operators()->Length(); ++j) { |
1907 | const tflite::Operator* op = subgraph->operators()->Get(j); |
1908 | const tflite::OperatorCode* op_code = |
1909 | model->operator_codes()->Get(op->opcode_index()); |
1910 | auto status = |
1911 | tflite::CheckGpuDelegateCompatibility(op_code, op, subgraph, model); |
1912 | if (!status.ok()) { |
1913 | gpu_compatibile = false; |
1914 | auto inst = subgraph_op_inst_map_[i][j]; |
1915 | tfl::AttachErrorCode( |
1916 | inst->emitOpError() |
1917 | << "is not GPU compatible: " << std::string(status.message()), |
1918 | tflite::metrics::ConverterErrorData::ERROR_GPU_NOT_COMPATIBLE); |
1919 | } |
1920 | } |
1921 | } |
1922 | return gpu_compatibile; |
1923 | } |
1924 | |
1925 | Optional<std::string> Translator::TranslateInternal() { |
1926 | // A list of named regions in the module with main function being the first in |
1927 | // the list. The main function is required as the first subgraph in the model |
1928 | // is entry point for the model. |
1929 | std::vector<std::pair<std::string, Region*>> named_regions; |
1930 | named_regions.reserve(std::distance(module_.begin(), module_.end())); |
1931 | |
1932 | int subgraph_idx = 0; |
1933 | |
1934 | // Entry functions for signature defs. |
1935 | std::vector<FuncOp> entry_functions; |
1936 | std::vector<FuncOp> non_entry_functions; |
1937 | FuncOp main_fn = module_.lookupSymbol<FuncOp>("main" ); |
1938 | if (main_fn != nullptr) { |
1939 | // Treat the main function as a signature def when the given main function |
1940 | // contains on the tf.entry_function attribute. |
1941 | auto attrs = |
1942 | main_fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function" ); |
1943 | if (attrs && !attrs.empty()) { |
1944 | entry_functions.push_back(main_fn); |
1945 | } else { |
1946 | non_entry_functions.push_back(main_fn); |
1947 | } |
1948 | } |
1949 | |
1950 | // Walk over the module collection ops with functions and while ops. |
1951 | module_.walk([&](FuncOp fn) { |
1952 | if (main_fn == fn) return WalkResult::advance(); |
1953 | auto attrs = fn->getAttrOfType<mlir::DictionaryAttr>("tf.entry_function" ); |
1954 | if (attrs && !attrs.empty()) { |
1955 | entry_functions.push_back(fn); |
1956 | } else { |
1957 | non_entry_functions.push_back(fn); |
1958 | } |
1959 | return WalkResult::advance(); |
1960 | }); |
1961 | |
1962 | // Assign the subgraph index. Among the given functions, it will put entry |
1963 | // functions at the beginning of the list of the subgrahs. |
1964 | for (auto fn : entry_functions) { |
1965 | subgraph_index_map_[fn.getName().str()] = subgraph_idx++; |
1966 | named_regions.emplace_back(fn.getName().str(), &fn.getBody()); |
1967 | } |
1968 | for (auto fn : non_entry_functions) { |
1969 | subgraph_index_map_[fn.getName().str()] = subgraph_idx++; |
1970 | named_regions.emplace_back(fn.getName().str(), &fn.getBody()); |
1971 | } |
1972 | |
1973 | // Build subgraph for each of the named regions. |
1974 | std::vector<BufferOffset<tflite::SubGraph>> subgraphs; |
1975 | subgraphs.reserve(named_regions.size()); |
1976 | model_control_dependencies_.assign(named_regions.size(), {}); |
1977 | int first_failed_func = -1; |
1978 | |
1979 | // When we export each function in the module op, intentionally, we export the |
1980 | // entry functions at the beginning of the subgraph list and the |
1981 | // subgraph_index is the index in entry functions and at the same, is the |
1982 | // index in the subgraph list. |
1983 | int subgraph_index = 0; |
1984 | for (const auto& it : llvm::enumerate(named_regions)) { |
1985 | auto subgraph_or = |
1986 | BuildSubGraph(it.value().first, it.value().second, subgraph_index); |
1987 | if (!subgraph_or) { |
1988 | if (first_failed_func == -1) |
1989 | // Record the index of the first region that cannot be converted. |
1990 | // Keep looping through all subgraphs in the module to make sure that |
1991 | // we collect the list of missing ops from the entire module. |
1992 | first_failed_func = it.index(); |
1993 | } else { |
1994 | subgraphs.push_back(*subgraph_or); |
1995 | ++subgraph_index; |
1996 | } |
1997 | } |
1998 | |
1999 | if (!resource_ops_.empty()) { |
2000 | std::string resource_ops_summary = |
2001 | GetOpsSummary(resource_ops_, /*summary_title=*/"Resource" ); |
2002 | LOG(WARNING) << "Graph contains the following resource op(s), that use(s) " |
2003 | "resource type. Currently, the " |
2004 | "resource type is not natively supported in TFLite. Please " |
2005 | "consider not using the resource type if there are issues " |
2006 | "with either TFLite converter or TFLite runtime:\n" |
2007 | << resource_ops_summary; |
2008 | } |
2009 | |
2010 | if (!flex_ops_.empty()) { |
2011 | std::string flex_ops_summary = |
2012 | GetOpsSummary(flex_ops_, /*summary_title=*/"Flex" ); |
2013 | LOG(WARNING) << "TFLite interpreter needs to link Flex delegate in order " |
2014 | "to run the model since it contains the following Select TF" |
2015 | "op(s):\n" |
2016 | << flex_ops_summary |
2017 | << "\nSee instructions: " |
2018 | "https://www.tensorflow.org/lite/guide/ops_select" ; |
2019 | } |
2020 | |
2021 | if (!custom_ops_.empty()) { |
2022 | std::string custom_ops_summary = |
2023 | GetOpsSummary(custom_ops_, /*summary_title=*/"Custom" ); |
2024 | LOG(WARNING) << "The following operation(s) need TFLite custom op " |
2025 | "implementation(s):\n" |
2026 | << custom_ops_summary |
2027 | << "\nSee instructions: " |
2028 | "https://www.tensorflow.org/lite/guide/ops_custom" ; |
2029 | } |
2030 | |
2031 | if (first_failed_func != -1) { |
2032 | std::string failed_flex_ops_summary = |
2033 | GetOpsSummary(failed_flex_ops_, /*summary_title=*/"TF Select" ); |
2034 | std::string failed_custom_ops_summary = |
2035 | GetOpsSummary(failed_custom_ops_, /*summary_title=*/"Custom" ); |
2036 | std::string err; |
2037 | if (!failed_flex_ops_.empty()) |
2038 | err += |
2039 | "\nSome ops are not supported by the native TFLite runtime, you can " |
2040 | "enable TF kernels fallback using TF Select. See instructions: " |
2041 | "https://www.tensorflow.org/lite/guide/ops_select \n" + |
2042 | failed_flex_ops_summary + "\n" ; |
2043 | if (!failed_custom_ops_.empty()) |
2044 | err += |
2045 | "\nSome ops in the model are custom ops, " |
2046 | "See instructions to implement " |
2047 | "custom ops: https://www.tensorflow.org/lite/guide/ops_custom \n" + |
2048 | failed_custom_ops_summary + "\n" ; |
2049 | |
2050 | auto& failed_region = named_regions[first_failed_func]; |
2051 | return failed_region.second->getParentOp()->emitError() |
2052 | << "failed while converting: '" << failed_region.first |
2053 | << "': " << err, |
2054 | llvm::None; |
2055 | } |
2056 | |
2057 | // Log MAC count. |
2058 | int64_t ops_count; |
2059 | if (EstimateArithmeticCount(&ops_count)) { |
2060 | const int64_t million = 1e6; |
2061 | const int64_t billion = 1e9; |
2062 | std::string flops_str; |
2063 | std::string mac_str; |
2064 | if (ops_count < 10000) { |
2065 | flops_str = absl::StrFormat("%ld " , ops_count); |
2066 | mac_str = absl::StrFormat("%ld " , ops_count / 2); |
2067 | } else if (ops_count < billion) { |
2068 | flops_str = |
2069 | absl::StrFormat("%.3f M " , static_cast<double>(ops_count) / million); |
2070 | mac_str = absl::StrFormat("%.3f M " , |
2071 | static_cast<double>(ops_count / 2) / million); |
2072 | } else { |
2073 | flops_str = |
2074 | absl::StrFormat("%.3f G " , static_cast<double>(ops_count) / billion); |
2075 | mac_str = absl::StrFormat("%.3f G " , |
2076 | static_cast<double>(ops_count / 2) / billion); |
2077 | } |
2078 | LOG(INFO) << "Estimated count of arithmetic ops: " << flops_str |
2079 | << " ops, equivalently " << mac_str << " MACs" ; |
2080 | } |
2081 | |
2082 | std::string model_description; |
2083 | if (auto attr = module_->getAttrOfType<StringAttr>("tfl.description" )) { |
2084 | model_description = attr.getValue().str(); |
2085 | } else { |
2086 | model_description = "MLIR Converted." ; |
2087 | } |
2088 | |
2089 | // Build the model and finish the model building process. |
2090 | auto description = builder_.CreateString(model_description.data()); |
2091 | VectorBufferOffset<int32_t> metadata_buffer = 0; // Deprecated |
2092 | auto metadata = CreateMetadataVector(); |
2093 | if (!metadata) return llvm::None; |
2094 | |
2095 | std::vector<SignatureDefData> signature_defs_vec; |
2096 | subgraph_index = 0; |
2097 | // Build SignatureDefs for the tf.entry_function based func ops. |
2098 | for (auto fn : entry_functions) { |
2099 | auto signature_defs = BuildSignaturedef( |
2100 | fn, saved_model_tags_.empty() ? "" : *saved_model_tags_.begin(), |
2101 | subgraph_index, name_mapper_); |
2102 | for (const auto& signature_def : signature_defs) { |
2103 | signature_defs_vec.push_back(signature_def); |
2104 | } |
2105 | // When we export each function in the module op, intentionally, we export |
2106 | // the entry functions at the beginning of the subgraph list and the |
2107 | // subgraph_index is the index in entry functions and at the same, is the |
2108 | // index in the subgraph list. |
2109 | ++subgraph_index; |
2110 | } |
2111 | auto signature_defs = CreateSignatureDefs(signature_defs_vec); |
2112 | |
2113 | auto model = tflite::CreateModel(builder_, TFLITE_SCHEMA_VERSION, |
2114 | builder_.CreateVector(opcodes_), |
2115 | builder_.CreateVector(subgraphs), |
2116 | description, builder_.CreateVector(buffers_), |
2117 | metadata_buffer, *metadata, *signature_defs); |
2118 | tflite::FinishModelBuffer(builder_, model); |
2119 | // There is a limit of 2GB for a flatbuffer. |
2120 | if (builder_.GetSize() > 2147483648) { |
2121 | LOG(ERROR) << "Model size is bigger than 2gb" ; |
2122 | return llvm::None; |
2123 | } |
2124 | tflite::UpdateOpVersion(builder_.GetBufferPointer()); |
2125 | tflite::UpdateMinimumRuntimeVersionForModel(builder_.GetBufferPointer()); |
2126 | if (supported_backends_.find("GPU" ) != supported_backends_.end()) { |
2127 | if (!CheckGpuDelegateCompatibility(builder_.GetBufferPointer())) { |
2128 | return llvm::None; |
2129 | } |
2130 | } |
2131 | |
2132 | // Return serialized string for the built FlatBuffer. |
2133 | return std::string(reinterpret_cast<const char*>(builder_.GetBufferPointer()), |
2134 | builder_.GetSize()); |
2135 | } |
2136 | |
2137 | BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters( |
2138 | const mlir::TFL::SparsityParameterAttr& s_attr) { |
2139 | const int dim_size = s_attr.getDimMetadata().size(); |
2140 | std::vector<flatbuffers::Offset<tflite::DimensionMetadata>> fb_dim_metadata( |
2141 | dim_size); |
2142 | for (int i = 0; i < dim_size; i++) { |
2143 | const auto dim_metadata = |
2144 | s_attr.getDimMetadata()[i].dyn_cast<mlir::TFL::DimensionMetadataAttr>(); |
2145 | if (dim_metadata.getFormat().getValue() == |
2146 | mlir::TFL::DimensionType::DENSE) { |
2147 | fb_dim_metadata[i] = tflite::CreateDimensionMetadata( |
2148 | builder_, tflite::DimensionType_DENSE, dim_metadata.getDenseSize()); |
2149 | |
2150 | } else { |
2151 | auto segments = dim_metadata.getSegments(); |
2152 | std::vector<int> vector_segments(segments.size(), 0); |
2153 | for (int j = 0, end = segments.size(); j < end; j++) { |
2154 | vector_segments[j] = segments[j]; |
2155 | } |
2156 | tflite::SparseIndexVector segments_type; |
2157 | BufferOffset<void> array_segments; |
2158 | // The segment array is sorted. |
2159 | // TODO(b/147449640): Clean this up with util functions. |
2160 | int max_of_segments = vector_segments[segments.size() - 1]; |
2161 | if (max_of_segments <= UINT8_MAX) { |
2162 | segments_type = tflite::SparseIndexVector_Uint8Vector; |
2163 | std::vector<uint8_t> uint8_vector(vector_segments.begin(), |
2164 | vector_segments.end()); |
2165 | array_segments = tflite::CreateUint8Vector( |
2166 | builder_, builder_.CreateVector(uint8_vector)) |
2167 | .Union(); |
2168 | } else if (max_of_segments <= UINT16_MAX) { |
2169 | segments_type = tflite::SparseIndexVector_Uint16Vector; |
2170 | std::vector<uint16_t> uint16_vector(vector_segments.begin(), |
2171 | vector_segments.end()); |
2172 | array_segments = tflite::CreateUint16Vector( |
2173 | builder_, builder_.CreateVector(uint16_vector)) |
2174 | .Union(); |
2175 | } else { |
2176 | segments_type = tflite::SparseIndexVector_Int32Vector; |
2177 | array_segments = tflite::CreateInt32Vector( |
2178 | builder_, builder_.CreateVector(vector_segments)) |
2179 | .Union(); |
2180 | } |
2181 | |
2182 | auto indices = dim_metadata.getIndices(); |
2183 | std::vector<int> vector_indices(indices.size(), 0); |
2184 | int max_of_indices = 0; |
2185 | for (int j = 0, end = indices.size(); j < end; j++) { |
2186 | vector_indices[j] = indices[j]; |
2187 | if (vector_indices[j] > max_of_indices) { |
2188 | max_of_indices = vector_indices[j]; |
2189 | } |
2190 | } |
2191 | tflite::SparseIndexVector indices_type; |
2192 | BufferOffset<void> array_indices; |
2193 | if (max_of_indices <= UINT8_MAX) { |
2194 | indices_type = tflite::SparseIndexVector_Uint8Vector; |
2195 | std::vector<uint8_t> uint8_vector(vector_indices.begin(), |
2196 | vector_indices.end()); |
2197 | array_indices = tflite::CreateUint8Vector( |
2198 | builder_, builder_.CreateVector(uint8_vector)) |
2199 | .Union(); |
2200 | } else if (max_of_indices <= UINT16_MAX) { |
2201 | indices_type = tflite::SparseIndexVector_Uint16Vector; |
2202 | std::vector<uint16_t> uint16_vector(vector_indices.begin(), |
2203 | vector_indices.end()); |
2204 | array_indices = tflite::CreateUint16Vector( |
2205 | builder_, builder_.CreateVector(uint16_vector)) |
2206 | .Union(); |
2207 | } else { |
2208 | indices_type = tflite::SparseIndexVector_Int32Vector; |
2209 | array_indices = tflite::CreateInt32Vector( |
2210 | builder_, builder_.CreateVector(vector_indices)) |
2211 | .Union(); |
2212 | } |
2213 | |
2214 | fb_dim_metadata[i] = tflite::CreateDimensionMetadata( |
2215 | builder_, tflite::DimensionType_SPARSE_CSR, 0, segments_type, |
2216 | array_segments, indices_type, array_indices); |
2217 | } |
2218 | } |
2219 | |
2220 | std::vector<int> traversal_order(dim_size); |
2221 | for (int i = 0; i < dim_size; i++) { |
2222 | traversal_order[i] = s_attr.getTraversalOrder()[i]; |
2223 | } |
2224 | const int block_map_size = s_attr.getBlockMap().size(); |
2225 | std::vector<int> block_map(block_map_size); |
2226 | for (int i = 0; i < block_map_size; i++) { |
2227 | block_map[i] = s_attr.getBlockMap()[i]; |
2228 | } |
2229 | |
2230 | return tflite::CreateSparsityParameters( |
2231 | builder_, builder_.CreateVector(traversal_order), |
2232 | builder_.CreateVector(block_map), builder_.CreateVector(fb_dim_metadata)); |
2233 | } |
2234 | |
2235 | std::vector<std::pair<int, int>> Translator::( |
2236 | mlir::Block* block) { |
2237 | std::vector<std::pair<int, int>> control_edges; |
2238 | |
2239 | mlir::IRRewriter rewriter(block->getParentOp()->getContext()); |
2240 | |
2241 | // Since we're modifying *block, we store integer offsets to block->begin(). |
2242 | llvm::DenseMap<Operation*, int> control_nodes_at; |
2243 | std::vector<Operation*> control_nodes; |
2244 | for (const auto& item : llvm::enumerate(*block)) { |
2245 | if (llvm::isa<mlir::TFL::ControlNodeOp>(item.value())) { |
2246 | control_nodes.push_back(&item.value()); |
2247 | control_nodes_at.try_emplace(&item.value(), item.index()); |
2248 | } |
2249 | } |
2250 | |
2251 | for (auto outer_op : control_nodes) { |
2252 | auto control_node_op = dyn_cast<mlir::TFL::ControlNodeOp>(outer_op); |
2253 | auto* inner_op = &control_node_op.body().front().front(); |
2254 | auto control_token = control_node_op.control(); |
2255 | |
2256 | // Now go through all uses. Since *block is in executable order, control |
2257 | // edges always point to operations we haven't modified yet. |
2258 | for (auto& use : control_token.getUses()) { |
2259 | auto owner = use.getOwner(); |
2260 | // Control tokens can only be consumed by other ControlNodeOps, |
2261 | assert(llvm::isa<mlir::TFL::ControlNodeOp>(owner)); |
2262 | assert(control_nodes_at.find(owner) != control_nodes_at.end()); |
2263 | // Control edge in terms of offsets. |
2264 | control_edges.emplace_back(control_nodes_at[outer_op], |
2265 | control_nodes_at[owner]); |
2266 | } |
2267 | control_token.dropAllUses(); |
2268 | |
2269 | // Replace the ControlNodeOp with the wrapped operation. |
2270 | rewriter.setInsertionPointAfter(outer_op); |
2271 | auto* cloned_inner = rewriter.clone(*inner_op); |
2272 | for (auto it : |
2273 | llvm::zip(control_node_op.outputs(), cloned_inner->getResults())) { |
2274 | std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); |
2275 | } |
2276 | rewriter.eraseOp(outer_op); |
2277 | } |
2278 | return control_edges; |
2279 | } |
2280 | |
2281 | } // namespace |
2282 | |
2283 | namespace tflite { |
2284 | |
2285 | bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module, |
2286 | const FlatbufferExportOptions& options, |
2287 | std::string* serialized_flatbuffer) { |
2288 | auto maybe_translated = Translator::Translate( |
2289 | module, options.toco_flags, options.saved_model_tags, |
2290 | options.op_or_arg_name_mapper, options.metadata); |
2291 | if (!maybe_translated) return false; |
2292 | *serialized_flatbuffer = std::move(*maybe_translated); |
2293 | return true; |
2294 | } |
2295 | |
2296 | } // namespace tflite |
2297 | |