1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/toco/import_tensorflow.h"
16
17#include <memory>
18#include <string>
19#include <utility>
20#include <vector>
21
22#include "google/protobuf/map.h"
23#include "google/protobuf/text_format.h"
24#include "absl/memory/memory.h"
25#include "absl/strings/match.h"
26#include "absl/strings/numbers.h"
27#include "absl/strings/str_cat.h"
28#include "absl/strings/str_split.h"
29#include "absl/strings/strip.h"
30#include "tensorflow/core/common_runtime/device_factory.h"
31#include "tensorflow/core/common_runtime/function.h"
32#include "tensorflow/core/common_runtime/graph_constructor.h"
33#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
34#include "tensorflow/core/framework/attr_value.pb.h"
35#include "tensorflow/core/framework/function.pb.h"
36#include "tensorflow/core/framework/graph.pb.h"
37#include "tensorflow/core/framework/node_def.pb.h"
38#include "tensorflow/core/framework/tensor.pb.h"
39#include "tensorflow/core/framework/tensor_shape.pb.h"
40#include "tensorflow/core/framework/types.pb.h"
41#include "tensorflow/core/lib/core/errors.h"
42#include "tensorflow/core/lib/core/status.h"
43#include "tensorflow/core/platform/logging.h"
44#include "tensorflow/core/public/session_options.h"
45#include "tensorflow/core/public/version.h"
46#include "tensorflow/lite/toco/model.h"
47#include "tensorflow/lite/toco/model_flags.pb.h"
48#include "tensorflow/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
49#include "tensorflow/lite/toco/tensorflow_util.h"
50#include "tensorflow/lite/toco/tooling_util.h"
51
52using tensorflow::AttrValue;
53using tensorflow::DT_BOOL;
54using tensorflow::DT_COMPLEX64;
55using tensorflow::DT_FLOAT;
56using tensorflow::DT_INT16;
57using tensorflow::DT_INT32;
58using tensorflow::DT_INT64;
59using tensorflow::DT_QUINT8;
60using tensorflow::DT_STRING;
61using tensorflow::DT_UINT16;
62using tensorflow::DT_UINT32;
63using tensorflow::DT_UINT8;
64using tensorflow::GraphDef;
65using tensorflow::NodeDef;
66using tensorflow::TensorProto;
67using tensorflow::TensorShapeProto;
68
69namespace toco {
70
71namespace {
72bool HasAttr(const NodeDef& node, const std::string& attr_name) {
73 return node.attr().count(attr_name) > 0;
74}
75
76bool HasWildcardDimension(const TensorShapeProto& shape) {
77 for (const auto& dim : shape.dim()) {
78 if (dim.size() == -1) return true;
79 }
80 return false;
81}
82
83const std::string& GetStringAttr(const NodeDef& node,
84 const std::string& attr_name) {
85 CHECK(HasAttr(node, attr_name));
86 const auto& attr = node.attr().at(attr_name);
87 CHECK_EQ(attr.value_case(), AttrValue::kS);
88 return attr.s();
89}
90
91int64_t GetIntAttr(const NodeDef& node, const std::string& attr_name) {
92 CHECK(HasAttr(node, attr_name)) << attr_name << " not found in:\n"
93 << node.DebugString();
94 const auto& attr = node.attr().at(attr_name);
95 CHECK_EQ(attr.value_case(), AttrValue::kI);
96 return attr.i();
97}
98
99float GetFloatAttr(const NodeDef& node, const std::string& attr_name) {
100 CHECK(HasAttr(node, attr_name));
101 const auto& attr = node.attr().at(attr_name);
102 CHECK_EQ(attr.value_case(), AttrValue::kF);
103 return attr.f();
104}
105
106bool GetBoolAttr(const NodeDef& node, const std::string& attr_name) {
107 CHECK(HasAttr(node, attr_name));
108 const auto& attr = node.attr().at(attr_name);
109 CHECK_EQ(attr.value_case(), AttrValue::kB);
110 return attr.b();
111}
112
113tensorflow::DataType GetDataTypeAttr(const NodeDef& node,
114 const std::string& attr_name) {
115 CHECK(HasAttr(node, attr_name));
116 const auto& attr = node.attr().at(attr_name);
117 CHECK_EQ(attr.value_case(), AttrValue::kType);
118 return attr.type();
119}
120
121const TensorShapeProto& GetShapeAttr(const NodeDef& node,
122 const std::string& attr_name) {
123 CHECK(HasAttr(node, attr_name));
124 const auto& attr = node.attr().at(attr_name);
125 CHECK_EQ(attr.value_case(), AttrValue::kShape);
126 return attr.shape();
127}
128
129const TensorProto& GetTensorAttr(const NodeDef& node,
130 const std::string& attr_name) {
131 CHECK(HasAttr(node, attr_name)) << "No attr named '" << attr_name << "'";
132 const auto& attr = node.attr().at(attr_name);
133 CHECK_EQ(attr.value_case(), AttrValue::kTensor);
134 return attr.tensor();
135}
136
137const AttrValue::ListValue& GetListAttr(const NodeDef& node,
138 const std::string& attr_name) {
139 CHECK(HasAttr(node, attr_name));
140 const auto& attr = node.attr().at(attr_name);
141 CHECK_EQ(attr.value_case(), AttrValue::kList);
142 return attr.list();
143}
144
145tensorflow::Status CheckOptionalAttr(const NodeDef& node,
146 const std::string& attr_name,
147 const std::string& expected_value) {
148 if (HasAttr(node, attr_name)) {
149 const std::string& value = GetStringAttr(node, attr_name);
150 if (value != expected_value) {
151 return tensorflow::errors::InvalidArgument(
152 "Unexpected value for attribute '" + attr_name + "'. Expected '" +
153 expected_value + "'");
154 }
155 }
156 return ::tensorflow::OkStatus();
157}
158
159tensorflow::Status CheckOptionalAttr(
160 const NodeDef& node, const std::string& attr_name,
161 const tensorflow::DataType& expected_value) {
162 if (HasAttr(node, attr_name)) {
163 const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name);
164 if (value != expected_value) {
165 return tensorflow::errors::InvalidArgument(
166 "Unexpected value for attribute '" + attr_name + "'. Expected '" +
167 tensorflow::DataType_Name(expected_value) + "'");
168 }
169 }
170 return ::tensorflow::OkStatus();
171}
172
173template <typename T1, typename T2>
174tensorflow::Status ExpectValue(const T1& v1, const T2& v2,
175 const std::string& description) {
176 if (v1 == v2) return ::tensorflow::OkStatus();
177 return tensorflow::errors::InvalidArgument(absl::StrCat(
178 "Unexpected ", description, ": got ", v1, ", expected ", v2));
179}
180
181ArrayDataType ConvertDataType(tensorflow::DataType dtype) {
182 if (dtype == DT_UINT8)
183 return ArrayDataType::kUint8;
184 else if (dtype == DT_FLOAT)
185 return ArrayDataType::kFloat;
186 else if (dtype == DT_BOOL)
187 return ArrayDataType::kBool;
188 else if (dtype == DT_INT16)
189 return ArrayDataType::kInt16;
190 else if (dtype == DT_UINT16)
191 return ArrayDataType::kUint16;
192 else if (dtype == DT_INT32)
193 return ArrayDataType::kInt32;
194 else if (dtype == DT_UINT32)
195 return ArrayDataType::kUint32;
196 else if (dtype == DT_INT64)
197 return ArrayDataType::kInt64;
198 else if (dtype == DT_STRING)
199 return ArrayDataType::kString;
200 else if (dtype == DT_COMPLEX64)
201 return ArrayDataType::kComplex64;
202 else
203 LOG(INFO) << "Unsupported data type in placeholder op: " << dtype;
204 return ArrayDataType::kNone;
205}
206
207tensorflow::Status ImportShape(
208 const TFLITE_PROTO_NS::RepeatedPtrField<tensorflow::TensorShapeProto_Dim>&
209 input_dims,
210 int* input_flat_size, Shape* shape) {
211 std::vector<int> input_dims_only_sizes;
212 bool zero_sized_shape = false;
213 for (auto& d : input_dims) {
214 // TensorFlow's shapes use int64s, while TOCO uses ints.
215 if (d.size() > std::numeric_limits<int>::max()) {
216 return tensorflow::errors::InvalidArgument("Shape element overflows");
217 }
218 if (d.size() == 0) {
219 zero_sized_shape = true;
220 }
221 input_dims_only_sizes.push_back(d.size());
222 }
223
224 // Note that up to this point we were OK with the input shape containing
225 // elements valued -1 or 0, which are perfectly legal in tensorflow. However
226 // our CheckValidShapeDimensions() insists on them being >= 1, with the
227 // exception of the "scalar" shape [0]. The main issue with zero-values shape
228 // elements is that the corresponding arrays don't contain any data and the
229 // allocation code gets a bit confused. It seems that the code expects an
230 // empty shape for zero-sized shapes, so we will do just that, except for the
231 // [0] case.
232 // TODO(b/119325030): In order to correctly import the "scalar" shapes the
233 // following test must include "&& input_dims_only_sizes.size() > 1", but
234 // that seems to slow everything down a lot.
235 if (zero_sized_shape) {
236 shape->mutable_dims()->clear();
237 if (input_flat_size != nullptr) *input_flat_size = 0;
238 return ::tensorflow::OkStatus();
239 }
240
241 *shape->mutable_dims() = input_dims_only_sizes;
242
243 if (input_flat_size == nullptr) return ::tensorflow::OkStatus();
244
245 return NumElements(input_dims_only_sizes, input_flat_size);
246}
247
// Per-element-type accessors for reading values out of a TensorProto. The
// primary template is only declared; each supported element type provides an
// explicit specialization.
// TODO(b/80208043): simply use tensorflow::Tensor::FromProto() instead.
template <class T>
struct TensorTraits;
252
253template <>
254struct TensorTraits<float> {
255 static int size(const TensorProto& p) { return p.float_val_size(); }
256 static float get(const TensorProto& p, int i) { return p.float_val(i); }
257 static std::string accessor_name() { return "float_val"; }
258 static std::string type_name() { return "float"; }
259 static void CopyFromContent(const TensorProto& p, std::vector<float>* data) {
260 toco::port::CopyToBuffer(p.tensor_content(),
261 reinterpret_cast<char*>(data->data()));
262 }
263};
264
265template <>
266struct TensorTraits<uint8_t> {
267 static int size(const TensorProto& p) { return p.int_val_size(); }
268 static uint8_t get(const TensorProto& p, int i) { return p.int_val(i); }
269 static std::string accessor_name() { return "int_val"; }
270 static std::string type_name() { return "uint8"; }
271 static void CopyFromContent(const TensorProto& p,
272 std::vector<uint8_t>* data) {
273 toco::port::CopyToBuffer(p.tensor_content(),
274 reinterpret_cast<char*>(data->data()));
275 }
276};
277
278template <>
279struct TensorTraits<std::complex<float>> {
280 static int size(const TensorProto& p) { return p.scomplex_val_size() / 2; }
281 static std::complex<float> get(const TensorProto& p, int i) {
282 return std::complex<float>(p.scomplex_val(2 * i),
283 p.scomplex_val(2 * i + 1));
284 }
285 static std::string accessor_name() { return "scomplex_val"; }
286 static std::string type_name() { return "complex64"; }
287 static void CopyFromContent(const TensorProto& p,
288 std::vector<std::complex<float>>* data) {
289 toco::port::CopyToBuffer(p.tensor_content(),
290 reinterpret_cast<char*>(data->data()));
291 }
292};
293
294template <>
295struct TensorTraits<int32> {
296 static int size(const TensorProto& p) { return p.int_val_size(); }
297 static int32 get(const TensorProto& p, int i) { return p.int_val(i); }
298 static std::string accessor_name() { return "int_val"; }
299 static std::string type_name() { return "int32"; }
300 static void CopyFromContent(const TensorProto& p, std::vector<int32>* data) {
301 toco::port::CopyToBuffer(p.tensor_content(),
302 reinterpret_cast<char*>(data->data()));
303 }
304};
305
306template <>
307struct TensorTraits<uint32> {
308 static int size(const TensorProto& p) { return p.uint32_val_size(); }
309 static int32 get(const TensorProto& p, int i) { return p.uint32_val(i); }
310 static std::string accessor_name() { return "uint32_val"; }
311 static std::string type_name() { return "uint32"; }
312 static void CopyFromContent(const TensorProto& p, std::vector<uint32>* data) {
313 toco::port::CopyToBuffer(p.tensor_content(),
314 reinterpret_cast<char*>(data->data()));
315 }
316};
317
318template <>
319struct TensorTraits<int64_t> {
320 static int size(const TensorProto& p) { return p.int64_val_size(); }
321 static int64_t get(const TensorProto& p, int i) { return p.int64_val(i); }
322 static std::string accessor_name() { return "int64_val"; }
323 static std::string type_name() { return "int64"; }
324 static void CopyFromContent(const TensorProto& p,
325 std::vector<int64_t>* data) {
326 toco::port::CopyToBuffer(p.tensor_content(),
327 reinterpret_cast<char*>(data->data()));
328 }
329};
330
331template <>
332struct TensorTraits<bool> {
333 static int size(const TensorProto& p) { return p.bool_val_size(); }
334 static bool get(const TensorProto& p, int i) { return p.bool_val(i); }
335 static std::string accessor_name() { return "bool_val"; }
336 static std::string type_name() { return "bool"; }
337 static void CopyFromContent(const TensorProto& p, std::vector<bool>* data) {
338 std::vector<char> buf(p.tensor_content().size());
339 toco::port::CopyToBuffer(p.tensor_content(), buf.data());
340 for (int i = 0; i < p.tensor_content().size(); i++) {
341 (*data)[i] = static_cast<bool>(buf[i]);
342 }
343 }
344};
345
346template <typename T>
347tensorflow::Status ImportTensorData(const TensorProto& input_tensor,
348 int input_flat_size,
349 std::vector<T>* output_data) {
350 CHECK_GE(output_data->size(), input_flat_size);
351 int num_elements_in_tensor = TensorTraits<T>::size(input_tensor);
352 if (num_elements_in_tensor == input_flat_size) {
353 for (int i = 0; i < num_elements_in_tensor; i++) {
354 (*output_data)[i] = TensorTraits<T>::get(input_tensor, i);
355 }
356 } else if (input_tensor.tensor_content().size() ==
357 input_flat_size * sizeof(T)) {
358 TensorTraits<T>::CopyFromContent(input_tensor, output_data);
359 } else if (num_elements_in_tensor >= 0 &&
360 num_elements_in_tensor < input_flat_size) {
361 // TODO(b/80208043): use tensorflow::Tensor::FromProto() which is the
362 // official way to import tensor data. This particular else-if handles a
363 // grappler optimization where the last few elements in a tensor are
364 // omitted if they are repeated, and where all elements are omitted if they
365 // are zero.
366 int i = 0;
367 for (; i < num_elements_in_tensor; ++i) {
368 (*output_data)[i] = TensorTraits<T>::get(input_tensor, i);
369 }
370 auto last = i == 0 ? T(0) : (*output_data)[i - 1];
371 for (; i < input_flat_size; ++i) {
372 (*output_data)[i] = last;
373 }
374 } else {
375 std::string accessor_name = TensorTraits<T>::accessor_name();
376 std::string type_name = TensorTraits<T>::type_name();
377 return tensorflow::errors::InvalidArgument(
378 absl::StrCat("Neither input_content (",
379 input_tensor.tensor_content().size() / sizeof(T), ") nor ",
380 accessor_name, " (", num_elements_in_tensor,
381 ") have the right dimensions (", input_flat_size,
382 ") for this ", type_name, " tensor"));
383 }
384 return ::tensorflow::OkStatus();
385}
386
387tensorflow::Status ImportFloatArray(const TensorProto& input_tensor,
388 Array* output_array) {
389 CHECK_EQ(input_tensor.dtype(), DT_FLOAT);
390 const auto& input_shape = input_tensor.tensor_shape();
391 CHECK_LE(input_shape.dim_size(), 6);
392 int input_flat_size;
393 auto status = ImportShape(input_shape.dim(), &input_flat_size,
394 output_array->mutable_shape());
395 if (!status.ok()) return status;
396
397 auto& output_float_data =
398 output_array->GetMutableBuffer<ArrayDataType::kFloat>().data;
399 output_float_data.resize(RequiredBufferSizeForShape(output_array->shape()),
400 0.f);
401 return ImportTensorData<float>(input_tensor, input_flat_size,
402 &output_float_data);
403}
404
405tensorflow::Status ImportComplex64Array(const TensorProto& input_tensor,
406 Array* output_array) {
407 CHECK_EQ(input_tensor.dtype(), DT_COMPLEX64);
408 const auto& input_shape = input_tensor.tensor_shape();
409 CHECK_LE(input_shape.dim_size(), 4);
410 int input_flat_size;
411 auto status = ImportShape(input_shape.dim(), &input_flat_size,
412 output_array->mutable_shape());
413 if (!status.ok()) return status;
414
415 auto& output_complex_data =
416 output_array->GetMutableBuffer<ArrayDataType::kComplex64>().data;
417 output_complex_data.resize(RequiredBufferSizeForShape(output_array->shape()),
418 std::complex<float>(0.f, 0.f));
419 return ImportTensorData<std::complex<float>>(input_tensor, input_flat_size,
420 &output_complex_data);
421}
422
423tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor,
424 Array* output_array) {
425 CHECK_EQ(input_tensor.dtype(), DT_QUINT8);
426 const auto& input_shape = input_tensor.tensor_shape();
427 CHECK_LE(input_shape.dim_size(), 6);
428 int input_flat_size;
429 auto status = ImportShape(input_shape.dim(), &input_flat_size,
430 output_array->mutable_shape());
431 if (!status.ok()) return status;
432
433 auto& output_int_data =
434 output_array->GetMutableBuffer<ArrayDataType::kUint8>().data;
435 output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
436 return ImportTensorData<uint8_t>(input_tensor, input_flat_size,
437 &output_int_data);
438}
439
440tensorflow::Status ImportInt32Array(const TensorProto& input_tensor,
441 Array* output_array) {
442 CHECK_EQ(input_tensor.dtype(), DT_INT32);
443 const auto& input_shape = input_tensor.tensor_shape();
444 CHECK_LE(input_shape.dim_size(), 6);
445 int input_flat_size;
446 auto status = ImportShape(input_shape.dim(), &input_flat_size,
447 output_array->mutable_shape());
448 if (!status.ok()) return status;
449
450 auto& output_int_data =
451 output_array->GetMutableBuffer<ArrayDataType::kInt32>().data;
452 output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
453 return ImportTensorData<int32>(input_tensor, input_flat_size,
454 &output_int_data);
455}
456
457tensorflow::Status ImportUint32Array(const TensorProto& input_tensor,
458 Array* output_array) {
459 CHECK_EQ(input_tensor.dtype(), DT_UINT32);
460 const auto& input_shape = input_tensor.tensor_shape();
461 CHECK_LE(input_shape.dim_size(), 6);
462 int input_flat_size;
463 auto status = ImportShape(input_shape.dim(), &input_flat_size,
464 output_array->mutable_shape());
465 if (!status.ok()) return status;
466
467 auto& output_int_data =
468 output_array->GetMutableBuffer<ArrayDataType::kUint32>().data;
469 output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
470 return ImportTensorData<uint32>(input_tensor, input_flat_size,
471 &output_int_data);
472}
473
474tensorflow::Status ImportInt64Array(const TensorProto& input_tensor,
475 Array* output_array) {
476 CHECK_EQ(input_tensor.dtype(), DT_INT64);
477 const auto& input_shape = input_tensor.tensor_shape();
478 CHECK_LE(input_shape.dim_size(), 6);
479 int input_flat_size;
480 auto status = ImportShape(input_shape.dim(), &input_flat_size,
481 output_array->mutable_shape());
482 if (!status.ok()) return status;
483
484 auto& output_int_data =
485 output_array->GetMutableBuffer<ArrayDataType::kInt64>().data;
486 output_int_data.resize(RequiredBufferSizeForShape(output_array->shape()), 0);
487 return ImportTensorData<int64_t>(input_tensor, input_flat_size,
488 &output_int_data);
489}
490
491tensorflow::Status ImportBoolArray(const TensorProto& input_tensor,
492 Array* output_array) {
493 CHECK_EQ(input_tensor.dtype(), DT_BOOL);
494 const auto& input_shape = input_tensor.tensor_shape();
495 CHECK_LE(input_shape.dim_size(), 6);
496 int input_flat_size;
497 auto status = ImportShape(input_shape.dim(), &input_flat_size,
498 output_array->mutable_shape());
499 if (!status.ok()) return status;
500
501 auto& output_bool_data =
502 output_array->GetMutableBuffer<ArrayDataType::kBool>().data;
503 output_bool_data.resize(RequiredBufferSizeForShape(output_array->shape()),
504 false);
505 status =
506 ImportTensorData<bool>(input_tensor, input_flat_size, &output_bool_data);
507 if (!status.ok() && output_bool_data.size() == 1) {
508 // Some graphs have bool const nodes without actual value...
509 // assuming that 'false' is implied.
510 // So far only encountered that in an array with 1 entry, let's
511 // require that until we encounter a graph where that's not the case.
512 output_bool_data[0] = false;
513 return ::tensorflow::OkStatus();
514 }
515 return status;
516}
517
518tensorflow::Status ImportStringArray(const TensorProto& input_tensor,
519 Array* output_array) {
520 CHECK_EQ(input_tensor.dtype(), DT_STRING);
521 const auto& input_shape = input_tensor.tensor_shape();
522 CHECK_LE(input_shape.dim_size(), 6);
523 int input_flat_size;
524 auto status = ImportShape(input_shape.dim(), &input_flat_size,
525 output_array->mutable_shape());
526 if (!status.ok()) return status;
527
528 if (input_flat_size != input_tensor.string_val_size()) {
529 return tensorflow::errors::InvalidArgument(
530 "Input_content string_val doesn't have the right dimensions "
531 "for this string tensor");
532 }
533
534 auto& output_string_data =
535 output_array->GetMutableBuffer<ArrayDataType::kString>().data;
536 output_string_data.resize(RequiredBufferSizeForShape(output_array->shape()));
537 CHECK_GE(output_string_data.size(), input_flat_size);
538 for (int i = 0; i < input_flat_size; ++i) {
539 output_string_data[i] = input_tensor.string_val(i);
540 }
541 return ::tensorflow::OkStatus();
542}
543
544// Count the number of inputs of a given node. If
545// `tf_import_flags.drop_control_dependency` is true, count the number of
546// non-control-dependency inputs.
547int GetInputsCount(const NodeDef& node,
548 const TensorFlowImportFlags& tf_import_flags) {
549 if (tf_import_flags.drop_control_dependency) {
550 for (size_t i = 0; i < node.input_size(); ++i) {
551 if (node.input(i)[0] == '^') {
552 return i;
553 }
554 }
555 }
556 return node.input_size();
557}
558
559tensorflow::Status CheckInputsCount(
560 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
561 int expected_input_count) {
562 if (GetInputsCount(node, tf_import_flags) != expected_input_count) {
563 return tensorflow::errors::FailedPrecondition(
564 node.op(), " node expects ", expected_input_count,
565 " input(s) other than control dependencies: ", node.DebugString());
566 }
567 return ::tensorflow::OkStatus();
568}
569
570template <ArrayDataType T>
571std::string CreateConstArray(
572 Model* model, std::string const& name,
573 std::vector<typename toco::DataType<T>> const& data) {
574 // Utility function to create a const 1D array, useful for input parameters.
575 std::string array_name = toco::AvailableArrayName(*model, name);
576 auto& array = model->GetOrCreateArray(array_name);
577 array.data_type = T;
578 array.mutable_shape()->mutable_dims()->emplace_back(
579 static_cast<int>(data.size()));
580 array.GetMutableBuffer<T>().data = data;
581 return array_name;
582}
583
584// Retain TensorFlow NodeDef in Toco Operator.
585//
586// If an op is supported by Toco but not supported by TFLite, TFLite exporter
587// will use the retained NodeDef to populate a Flex op when Flex mode is
588// enabled.
589//
590// This can't be easily applied to all operations, because a TensorFlow node
591// may become multiple Toco operators. Thus we need to call this function in
592// operator conversion functions one by one whenever feasible.
593//
594// This may cause problems if a graph transformation rule changes parameters
595// of the node. When calling this function, please check if any existing
596// graph transformation rule will change an existing operator with the same
597// type.
598//
599// This provides a route to handle Toco-supported & TFLite-unsupported ops
600// in Flex mode. However it's not a solid solution. Eventually we should
601// get rid of this.
602// TODO(b/117327937): Implement all Toco-supported ops in TFLite, and remove
603// this function.
604void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
605 node.SerializeToString(&op->tensorflow_node_def);
606}
607
608void GetOutputNamesFromNodeDef(const NodeDef& node,
609 const tensorflow::OpDef& op_def,
610 TensorFlowUnsupportedOperator* op) {
611 int next_output = 0;
612 auto add_output = [&node, &next_output, op]() {
613 if (next_output == 0) {
614 op->outputs.push_back(node.name()); // Implicit :0.
615 } else {
616 op->outputs.push_back(absl::StrCat(node.name(), ":", next_output));
617 }
618 ++next_output;
619 };
620 for (int i = 0; i < op_def.output_arg_size(); ++i) {
621 std::string multiples = op_def.output_arg(i).number_attr();
622 if (!multiples.empty()) {
623 CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
624 int num_outputs = GetIntAttr(node, multiples);
625 for (int j = 0; j < num_outputs; ++j) {
626 add_output();
627 }
628 } else {
629 std::string list = op_def.output_arg(i).type_list_attr();
630 if (!list.empty()) {
631 CHECK(HasAttr(node, list)) << "No attr named " << list;
632 const AttrValue::ListValue& list_value = GetListAttr(node, list);
633 for (int j = 0; j < list_value.type_size(); ++j) {
634 add_output();
635 }
636 } else {
637 add_output();
638 }
639 }
640 }
641}
642
643void GetOutputTypesFromNodeDef(const NodeDef& node,
644 const tensorflow::OpDef& op_def,
645 TensorFlowUnsupportedOperator* op) {
646 // The given type to the op, or clear the types if invalid.
647 auto add_type = [&node, op](tensorflow::DataType type) {
648 if (type == tensorflow::DT_INVALID) {
649 LOG(WARNING) << "Op node missing output type attribute: " << node.name();
650 op->output_data_types.clear();
651 } else {
652 op->output_data_types.push_back(ConvertDataType(type));
653 }
654 };
655
656 // Retrieve the data type according to the OpDef definition: either the
657 // "type" or "type_attr" field will be set.
658 auto get_type = [&node](const tensorflow::OpDef::ArgDef& a) {
659 if (a.type() != tensorflow::DT_INVALID) {
660 return a.type();
661 } else if (HasAttr(node, a.type_attr())) {
662 return GetDataTypeAttr(node, a.type_attr());
663 } else {
664 return tensorflow::DT_INVALID;
665 }
666 };
667
668 for (int i = 0; i < op_def.output_arg_size(); ++i) {
669 std::string multiples = op_def.output_arg(i).number_attr();
670 if (!multiples.empty()) {
671 CHECK(HasAttr(node, multiples)) << "No attr named " << multiples;
672 int num_outputs = GetIntAttr(node, multiples);
673 auto type = get_type(op_def.output_arg(i));
674 for (int j = 0; j < num_outputs; ++j) {
675 add_type(type);
676 }
677 } else {
678 std::string list = op_def.output_arg(i).type_list_attr();
679 if (!list.empty()) {
680 CHECK(HasAttr(node, list)) << "No attr named " << list;
681 const AttrValue::ListValue& list_value = GetListAttr(node, list);
682 for (int j = 0; j < list_value.type_size(); ++j) {
683 add_type(list_value.type(j));
684 }
685 } else {
686 add_type(get_type(op_def.output_arg(i)));
687 }
688 }
689 }
690}
691
692tensorflow::Status ConvertUnsupportedOperator(
693 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
694 const ModelFlags& model_flags, Model* model) {
695 // Names of special attributes in TF graph that are used by Toco.
696 static constexpr char kAttrOutputQuantized[] = "_output_quantized";
697 static constexpr char kAttrOutputTypes[] = "_output_types";
698 static constexpr char kAttrOutputShapes[] = "_output_shapes";
699 static constexpr char kAttrSupportOutputTypeFloatInQuantizedOp[] =
700 "_support_output_type_float_in_quantized_op";
701
702 LOG(INFO) << "Converting unsupported operation: " << node.op();
703
704 auto* op = new TensorFlowUnsupportedOperator;
705 op->tensorflow_op = node.op();
706
707 // For Flex mode. Please read the comments of the function.
708 RetainTensorFlowNodeDef(node, op);
709
710 model->operators.emplace_back(op);
711
712 // Parse inputs.
713 const int num_inputs = GetInputsCount(node, tf_import_flags);
714 for (int i = 0; i < num_inputs; ++i) {
715 op->inputs.push_back(node.input(i));
716 }
717
718 // Parse outputs. Name them after the node's name, plus an ordinal suffix.
719 // Note that some outputs are to be multiplied by a named attribute.
720 const tensorflow::OpDef* op_def = nullptr;
721 if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
722 GetOutputNamesFromNodeDef(node, *op_def, op);
723 } else {
724 op->outputs.push_back(node.name()); // Implicit :0.
725 }
726
727 // Parse if the op supports quantization
728 if (HasAttr(node, kAttrOutputQuantized)) {
729 op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
730 }
731 // Parse if the quantized op allows output arrays of type float
732 if (HasAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp)) {
733 op->support_output_type_float_in_quantized_op =
734 GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
735 }
736
737 // Parse output type(s).
738 if (HasAttr(node, kAttrOutputTypes)) {
739 const auto& output_types = GetListAttr(node, kAttrOutputTypes);
740 for (int i = 0; i < output_types.type_size(); ++i) {
741 op->output_data_types.push_back(ConvertDataType(output_types.type(i)));
742 }
743 } else if (HasAttr(node, "Tout")) {
744 const auto& output_type = GetDataTypeAttr(node, "Tout");
745 op->output_data_types.push_back(ConvertDataType(output_type));
746 } else if (op_def != nullptr) {
747 GetOutputTypesFromNodeDef(node, *op_def, op);
748 } else {
749 // TODO(b/113613439): Figure out how to propagate types for custom ops
750 // that have no OpDef.
751 LOG(INFO) << "Unable to determine output type for op: " << node.op();
752 }
753
754 // Parse output shape(s).
755 if (HasAttr(node, kAttrOutputShapes)) {
756 const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
757 Shape output_shape;
758 for (int i = 0; i < output_shapes.shape_size(); ++i) {
759 const auto& shape = output_shapes.shape(i);
760 // TOCO doesn't yet properly handle shapes with wildcard dimensions.
761 // TODO(b/113613439): Handle shape inference for unsupported ops that have
762 // shapes with wildcard dimensions.
763 if (HasWildcardDimension(shape)) {
764 LOG(INFO) << "Skipping wildcard output shape(s) for node: "
765 << node.name();
766 op->output_shapes.clear();
767 break;
768 }
769 const auto status =
770 ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape);
771 if (!status.ok()) {
772 return status;
773 }
774 op->output_shapes.push_back(output_shape);
775 }
776 }
777 return ::tensorflow::OkStatus();
778}
779
780tensorflow::Status ConvertConstOperator(
781 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
782 const ModelFlags& model_flags, Model* model) {
783 CHECK_EQ(node.op(), "Const");
784 const auto& tensor = GetTensorAttr(node, "value");
785 const auto dtype = GetDataTypeAttr(node, "dtype");
786
787 tensorflow::Status status = ::tensorflow::OkStatus();
788
789 auto& array = model->GetOrCreateArray(node.name());
790 switch (dtype) {
791 case DT_FLOAT:
792 array.data_type = ArrayDataType::kFloat;
793 status = ImportFloatArray(tensor, &array);
794 break;
795 case DT_INT32:
796 array.data_type = ArrayDataType::kInt32;
797 status = ImportInt32Array(tensor, &array);
798 break;
799 case DT_UINT32:
800 array.data_type = ArrayDataType::kUint32;
801 status = ImportUint32Array(tensor, &array);
802 break;
803 case DT_QUINT8:
804 array.data_type = ArrayDataType::kUint8;
805 status = ImportQuint8Array(tensor, &array);
806 break;
807 case DT_INT64:
808 array.data_type = ArrayDataType::kInt64;
809 status = ImportInt64Array(tensor, &array);
810 break;
811 case DT_STRING:
812 array.data_type = ArrayDataType::kString;
813 status = ImportStringArray(tensor, &array);
814 break;
815 case DT_BOOL:
816 array.data_type = ArrayDataType::kBool;
817 status = ImportBoolArray(tensor, &array);
818 break;
819 case DT_COMPLEX64:
820 array.data_type = ArrayDataType::kComplex64;
821 status = ImportComplex64Array(tensor, &array);
822 break;
823 default:
824 array.data_type = ArrayDataType::kNone;
825 // do nothing, silently ignore the Const data.
826 // We just make a dummy buffer to indicate that
827 // this array does not rely on external input.
828 array.GetMutableBuffer<ArrayDataType::kNone>();
829 break;
830 }
831 TF_RETURN_WITH_CONTEXT_IF_ERROR(
832 status, " (while processing node '" + node.name() + "')");
833 return ::tensorflow::OkStatus();
834}
835
836tensorflow::Status ConvertConvOperator(
837 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
838 const ModelFlags& model_flags, Model* model) {
839 CHECK_EQ(node.op(), "Conv2D");
840 TF_RETURN_IF_ERROR(CheckInputsCount(node, tf_import_flags, 2));
841
842 // We only support NHWC, which is the default data_format.
843 // So if data_format is not defined, we're all good.
844 TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC"));
845 TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT));
846
847 const auto& input_name = node.input(0);
848 const auto& weights_name = node.input(1);
849 const auto& reordered_weights_name =
850 AvailableArrayName(*model, weights_name + "_reordered");
851 // Check if a ReorderAxesOperator was already created for these weights
852 // (that happens when multiple layers share the same weights).
853 const Operator* existing_reorder =
854 GetOpWithOutput(*model, reordered_weights_name);
855 if (existing_reorder) {
856 // Check that it is safe to rely on the _reordered naming of the output
857 // array!
858 CHECK(existing_reorder->type == OperatorType::kReorderAxes);
859 } else {
860 // Create a new ReorderAxesOperator
861 auto* reorder = new ReorderAxesOperator;
862 reorder->inputs = {weights_name};
863 reorder->outputs = {reordered_weights_name};
864 reorder->input_axes_order = AxesOrder::kHWIO;
865 reorder->output_axes_order = AxesOrder::kOHWI;
866 model->operators.emplace_back(reorder);
867 }
868 if (!HasAttr(node, "strides")) {
869 return tensorflow::errors::InvalidArgument("Missing attribute 'strides'");
870 }
871 const auto& strides = GetListAttr(node, "strides");
872 TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
873 TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
874 TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
875 int dilation_height_factor;
876 int dilation_width_factor;
877 if (HasAttr(node, "dilations")) {
878 const auto& dilations = GetListAttr(node, "dilations");
879 TF_RETURN_IF_ERROR(
880 ExpectValue(dilations.i_size(), 4, "number of dilations"));
881 if (dilations.i(0) != 1 || dilations.i(3) != 1) {
882 return tensorflow::errors::InvalidArgument(absl::StrCat(
883 "Can only import Conv ops with dilation along the height "
884 "(1st) or width (2nd) axis. TensorFlow op \"",
885 node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
886 dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
887 }
888 dilation_height_factor = dilations.i(1);
889 dilation_width_factor = dilations.i(2);
890 } else {
891 dilation_height_factor = 1;
892 dilation_width_factor = 1;
893 }
894 const auto& padding = GetStringAttr(node, "padding");
895 PaddingType padding_type;
896 if (padding == "SAME") {
897 padding_type = PaddingType::kSame;
898 } else if (padding == "VALID") {
899 padding_type = PaddingType::kValid;
900 } else {
901 return tensorflow::errors::InvalidArgument(
902 "Bad padding (only SAME and VALID are supported)");
903 }
904 auto* conv = new ConvOperator;
905 conv->inputs = {input_name, reordered_weights_name};
906 conv->outputs = {node.name()};
907 conv->stride_height = strides.i(1);
908 conv->stride_width = strides.i(2);
909 conv->dilation_height_factor = dilation_height_factor;
910 conv->dilation_width_factor = dilation_width_factor;
911 conv->padding.type = padding_type;
912 model->operators.emplace_back(conv);
913
914 return ::tensorflow::OkStatus();
915}
916
917tensorflow::Status ConvertDepthwiseConvOperator(
918 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
919 const ModelFlags& model_flags, Model* model) {
920 CHECK_EQ(node.op(), "DepthwiseConv2dNative");
921 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
922
923 // We only support NHWC, which is the default data_format.
924 // So if data_format is not defined, we're all good.
925 if (HasAttr(node, "data_format")) {
926 CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
927 }
928 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
929
930 const auto& input_name = node.input(0);
931 const auto& weights_name = node.input(1);
932 const auto& reordered_weights_name = weights_name + "_reordered";
933 // Check if a ReorderAxesOperator was already created for these weights
934 // (that happens when multiple layers share the same weights).
935 const Operator* existing_reorder =
936 GetOpWithOutput(*model, reordered_weights_name);
937 if (existing_reorder) {
938 // Check that it is safe to rely on the _reordered naming of the output
939 // array!
940 CHECK(existing_reorder->type == OperatorType::kReorderAxes);
941 } else {
942 // Create a new ReorderAxesOperator
943 auto* reorder = new ReorderAxesOperator;
944 reorder->inputs = {weights_name};
945 reorder->outputs = {reordered_weights_name};
946 reorder->input_axes_order = AxesOrder::kHWIM;
947 reorder->output_axes_order = AxesOrder::k1HWO;
948 model->operators.emplace_back(reorder);
949 }
950 const auto& strides = GetListAttr(node, "strides");
951 TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides"));
952 TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)"));
953 TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)"));
954 int dilation_height_factor;
955 int dilation_width_factor;
956 if (HasAttr(node, "dilations")) {
957 const auto& dilations = GetListAttr(node, "dilations");
958 TF_RETURN_IF_ERROR(
959 ExpectValue(dilations.i_size(), 4, "number of dilations"));
960 if (dilations.i(0) != 1 || dilations.i(3) != 1) {
961 return tensorflow::errors::InvalidArgument(absl::StrCat(
962 "Can only import Conv ops with dilation along the height "
963 "(1st) or width (2nd) axis. TensorFlow op \"",
964 node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
965 dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
966 }
967 dilation_height_factor = dilations.i(1);
968 dilation_width_factor = dilations.i(2);
969 } else {
970 dilation_height_factor = 1;
971 dilation_width_factor = 1;
972 }
973 const auto& padding = GetStringAttr(node, "padding");
974 PaddingType padding_type;
975 if (padding == "SAME") {
976 padding_type = PaddingType::kSame;
977 } else if (padding == "VALID") {
978 padding_type = PaddingType::kValid;
979 } else {
980 return tensorflow::errors::InvalidArgument(
981 "Bad padding (only SAME and VALID are supported)");
982 }
983 auto* conv = new DepthwiseConvOperator;
984 conv->inputs = {input_name, reordered_weights_name};
985 conv->outputs = {node.name()};
986 conv->stride_height = strides.i(1);
987 conv->stride_width = strides.i(2);
988 conv->dilation_height_factor = dilation_height_factor;
989 conv->dilation_width_factor = dilation_width_factor;
990 conv->padding.type = padding_type;
991 model->operators.emplace_back(conv);
992 return ::tensorflow::OkStatus();
993}
994
995tensorflow::Status ConvertDepthToSpaceOperator(
996 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
997 const ModelFlags& model_flags, Model* model) {
998 CHECK_EQ(node.op(), "DepthToSpace");
999 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1000
1001 tensorflow::DataType dtype = GetDataTypeAttr(node, "T");
1002 if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 &&
1003 dtype != DT_INT64) {
1004 const auto* enum_descriptor = tensorflow::DataType_descriptor();
1005 LOG(FATAL) << "TFLite does not support DepthToSpace with type T:"
1006 << enum_descriptor->FindValueByNumber(dtype)->name() << ". "
1007 << "T must be one of {DT_FLOAT, DT_UINT8, DT_INT32, DT_INT64}.";
1008 }
1009 auto* op = new DepthToSpaceOperator;
1010 op->inputs.push_back(node.input(0));
1011 op->outputs.push_back(node.name());
1012 op->block_size = GetIntAttr(node, "block_size");
1013 QCHECK_GE(op->block_size, 2);
1014 model->operators.emplace_back(op);
1015 return ::tensorflow::OkStatus();
1016}
1017
1018tensorflow::Status ConvertSpaceToDepthOperator(
1019 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1020 const ModelFlags& model_flags, Model* model) {
1021 CHECK_EQ(node.op(), "SpaceToDepth");
1022 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1023
1024 tensorflow::DataType dtype = GetDataTypeAttr(node, "T");
1025 if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 &&
1026 dtype != DT_INT64) {
1027 const auto* enum_descriptor = tensorflow::DataType_descriptor();
1028 LOG(FATAL) << "TFLite does not support SpaceToDepth with type T:"
1029 << enum_descriptor->FindValueByNumber(dtype)->name() << ". "
1030 << "T must be one of {DT_FLOAT, DT_UINT8, DT_INT32, DT_INT64}.";
1031 }
1032 auto* op = new SpaceToDepthOperator;
1033 op->inputs.push_back(node.input(0));
1034 op->outputs.push_back(node.name());
1035 op->block_size = GetIntAttr(node, "block_size");
1036 QCHECK_GE(op->block_size, 2);
1037 model->operators.emplace_back(op);
1038 return ::tensorflow::OkStatus();
1039}
1040
1041tensorflow::Status ConvertBiasAddOperator(
1042 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1043 const ModelFlags& model_flags, Model* model) {
1044 CHECK_EQ(node.op(), "BiasAdd");
1045 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1046
1047 const auto& input_name = node.input(0);
1048 const auto& bias_name = node.input(1);
1049 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
1050 auto* biasadd = new AddOperator;
1051 biasadd->inputs.push_back(input_name);
1052 biasadd->inputs.push_back(bias_name);
1053 biasadd->outputs.push_back(node.name());
1054 model->operators.emplace_back(biasadd);
1055 return ::tensorflow::OkStatus();
1056}
1057
1058tensorflow::Status ConvertRandomUniform(
1059 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1060 const ModelFlags& model_flags, Model* model) {
1061 CHECK_EQ(node.op(), "RandomUniform");
1062 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1063
1064 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_INT32);
1065 auto op = std::make_unique<RandomUniformOperator>();
1066 op->inputs.push_back(node.input(0));
1067 op->outputs.push_back(node.name());
1068 op->dtype = ConvertDataType(GetDataTypeAttr(node, "dtype"));
1069 op->seed = GetIntAttr(node, "seed");
1070 op->seed2 = GetIntAttr(node, "seed2");
1071 CHECK(model != nullptr);
1072 model->operators.emplace_back(std::move(op));
1073 return ::tensorflow::OkStatus();
1074}
1075
1076tensorflow::Status ConvertIdentityOperator(
1077 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1078 const ModelFlags& model_flags, Model* model) {
1079 CHECK(node.op() == "Identity" || node.op() == "CheckNumerics" ||
1080 node.op() == "PlaceholderWithDefault" || node.op() == "StopGradient" ||
1081 node.op() == "Snapshot" || node.op() == "EnsureShape");
1082 auto* op = new TensorFlowIdentityOperator;
1083 // Amazingly, some TensorFlow graphs (at least rajeev_lstm.pb) have
1084 // identity nodes with multiple inputs, but the other inputs seem
1085 // to be gratuitous (in the case of rajeev_lstm.pb, these are
1086 // enumerating the LSTM state arrays). We will just ignore extra
1087 // inputs beyond the first input.
1088 QCHECK_GE(node.input_size(), 1)
1089 << node.op()
1090 << " node expects at least 1 input other than control dependencies: "
1091 << node.DebugString();
1092 const auto& input_name = node.input(0);
1093 op->inputs.push_back(input_name);
1094 op->outputs.push_back(node.name());
1095 model->operators.emplace_back(op);
1096 return ::tensorflow::OkStatus();
1097}
1098
1099tensorflow::Status ConvertIdentityNOperator(
1100 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1101 const ModelFlags& model_flags, Model* model) {
1102 CHECK_EQ(node.op(), "IdentityN");
1103 for (int i = 0; i < node.input_size(); ++i) {
1104 auto* op = new TensorFlowIdentityOperator;
1105 const auto& input_name = node.input(i);
1106 std::string output_name = node.name();
1107 if (i > 0) {
1108 output_name = output_name + ":" + std::to_string(i);
1109 }
1110 op->inputs.push_back(input_name);
1111 op->outputs.push_back(output_name);
1112 model->operators.emplace_back(op);
1113 }
1114 return ::tensorflow::OkStatus();
1115}
1116
1117tensorflow::Status ConvertFakeQuantWithMinMaxArgs(
1118 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1119 const ModelFlags& model_flags, Model* model) {
1120 CHECK_EQ(node.op(), "FakeQuantWithMinMaxArgs");
1121 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1122 auto* op = new FakeQuantOperator;
1123 op->inputs.push_back(node.input(0));
1124 op->minmax = std::make_unique<MinMax>();
1125 auto& minmax = *op->minmax;
1126 minmax.min = GetFloatAttr(node, "min");
1127 minmax.max = GetFloatAttr(node, "max");
1128 op->outputs.push_back(node.name());
1129 // tf.fake_quant_with_min_max_args num_bits defaults to 8.
1130 op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
1131 if (HasAttr(node, "narrow_range")) {
1132 op->narrow_range = GetBoolAttr(node, "narrow_range");
1133 }
1134 model->operators.emplace_back(op);
1135 return ::tensorflow::OkStatus();
1136}
1137
1138tensorflow::Status ConvertFakeQuantWithMinMaxVars(
1139 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1140 const ModelFlags& model_flags, Model* model) {
1141 CHECK_EQ(node.op(), "FakeQuantWithMinMaxVars");
1142 const int num_inputs = GetInputsCount(node, tf_import_flags);
1143 QCHECK(num_inputs == 3 || num_inputs == 4)
1144 << "FakeQuantWithMinMaxVars node expects 3 or 4 inputs other than "
1145 "control dependencies: "
1146 << node.DebugString();
1147 auto* op = new FakeQuantOperator;
1148 for (int i = 0; i < 3; i++) {
1149 op->inputs.push_back(node.input(i));
1150 }
1151 op->outputs.push_back(node.name());
1152 op->num_bits = HasAttr(node, "num_bits") ? GetIntAttr(node, "num_bits") : 8;
1153 if (HasAttr(node, "narrow_range")) {
1154 op->narrow_range = GetBoolAttr(node, "narrow_range");
1155 }
1156 model->operators.emplace_back(op);
1157 return ::tensorflow::OkStatus();
1158}
1159
1160tensorflow::Status ConvertSqueezeOperator(
1161 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1162 const ModelFlags& model_flags, Model* model) {
1163 CHECK_EQ(node.op(), "Squeeze");
1164 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1165 auto* op = new SqueezeOperator;
1166 op->inputs.push_back(node.input(0));
1167 op->outputs.push_back(node.name());
1168
1169 // When omitted we are to squeeze all dimensions == 1.
1170 if (HasAttr(node, "squeeze_dims")) {
1171 const auto& squeeze_dims = GetListAttr(node, "squeeze_dims");
1172 for (int i = 0; i < squeeze_dims.i_size(); ++i) {
1173 op->squeeze_dims.push_back(squeeze_dims.i(i));
1174 }
1175 }
1176
1177 model->operators.emplace_back(op);
1178 return ::tensorflow::OkStatus();
1179}
1180
1181tensorflow::Status ConvertSplitOperator(
1182 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1183 const ModelFlags& model_flags, Model* model) {
1184 CHECK_EQ(node.op(), "Split");
1185 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1186 auto* op = new TensorFlowSplitOperator;
1187 op->inputs.push_back(node.input(0));
1188 op->inputs.push_back(node.input(1));
1189 const int num_split = GetIntAttr(node, "num_split");
1190 op->outputs.push_back(node.name());
1191 for (int i = 1; i < num_split; i++) {
1192 op->outputs.push_back(absl::StrCat(node.name(), ":", i));
1193 }
1194 op->num_split = num_split;
1195 model->operators.emplace_back(op);
1196 return ::tensorflow::OkStatus();
1197}
1198
1199tensorflow::Status ConvertSplitVOperator(
1200 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1201 const ModelFlags& model_flags, Model* model) {
1202 CHECK_EQ(node.op(), "SplitV");
1203 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1204 auto* op = new TensorFlowSplitVOperator;
1205 op->inputs.push_back(node.input(0));
1206 op->inputs.push_back(node.input(1));
1207 op->inputs.push_back(node.input(2));
1208 const int num_split = GetIntAttr(node, "num_split");
1209 op->outputs.push_back(node.name());
1210 for (int i = 1; i < num_split; i++) {
1211 op->outputs.push_back(absl::StrCat(node.name(), ":", i));
1212 }
1213 op->num_split = num_split;
1214 model->operators.emplace_back(op);
1215 return ::tensorflow::OkStatus();
1216}
1217
1218tensorflow::Status ConvertSwitchOperator(
1219 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1220 const ModelFlags& model_flags, Model* model) {
1221 CHECK_EQ(node.op(), "Switch");
1222 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1223 auto* op = new TensorFlowSwitchOperator;
1224 op->inputs.push_back(node.input(0));
1225 op->inputs.push_back(node.input(1));
1226 op->outputs.push_back(node.name());
1227 // Switch operators have two outputs: "name" and "name:1".
1228 op->outputs.push_back(node.name() + ":1");
1229 model->operators.emplace_back(op);
1230 return ::tensorflow::OkStatus();
1231}
1232
1233tensorflow::Status ConvertSoftmaxOperator(
1234 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1235 const ModelFlags& model_flags, Model* model) {
1236 CHECK_EQ(node.op(), "Softmax");
1237 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1238 const auto& input_name = node.input(0);
1239 auto* softmax = new SoftmaxOperator;
1240 softmax->inputs.push_back(input_name);
1241 softmax->outputs.push_back(node.name());
1242 // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter.
1243 CHECK(!node.attr().count("beta")); // Stab in the dark, just in case.
1244 if (node.attr().count("_softmax_beta")) {
1245 softmax->beta = GetFloatAttr(node, "_softmax_beta");
1246 } else {
1247 softmax->beta = 1.f;
1248 }
1249 model->operators.emplace_back(softmax);
1250 return ::tensorflow::OkStatus();
1251}
1252
1253tensorflow::Status ConvertLRNOperator(
1254 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1255 const ModelFlags& model_flags, Model* model) {
1256 CHECK_EQ(node.op(), "LRN");
1257 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1258 const auto& input_name = node.input(0);
1259 auto* lrn = new LocalResponseNormalizationOperator;
1260 lrn->inputs.push_back(input_name);
1261 lrn->outputs.push_back(node.name());
1262 lrn->range = GetIntAttr(node, "depth_radius");
1263 lrn->bias = GetFloatAttr(node, "bias");
1264 lrn->alpha = GetFloatAttr(node, "alpha");
1265 lrn->beta = GetFloatAttr(node, "beta");
1266 model->operators.emplace_back(lrn);
1267 return ::tensorflow::OkStatus();
1268}
1269
1270tensorflow::Status ConvertMaxPoolOperator(
1271 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1272 const ModelFlags& model_flags, Model* model) {
1273 CHECK_EQ(node.op(), "MaxPool");
1274 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1275 const auto& input_name = node.input(0);
1276 // We only support NHWC, which is the default data_format.
1277 // So if data_format is not defined, we're all good.
1278 if (node.attr().count("data_format")) {
1279 CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
1280 }
1281 if (HasAttr(node, "T")) {
1282 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
1283 } else {
1284 LOG(WARNING) << "Found MaxPool operator missing 'T' attribute";
1285 }
1286 auto* maxpool = new MaxPoolOperator;
1287 maxpool->inputs.push_back(input_name);
1288 maxpool->outputs.push_back(node.name());
1289 const auto& strides = GetListAttr(node, "strides");
1290 CHECK_EQ(strides.i_size(), 4);
1291 CHECK_EQ(strides.i(0), 1);
1292 CHECK_EQ(strides.i(3), 1);
1293 maxpool->stride_height = strides.i(1);
1294 maxpool->stride_width = strides.i(2);
1295 const auto& ksize = GetListAttr(node, "ksize");
1296 CHECK_EQ(ksize.i_size(), 4);
1297 CHECK_EQ(ksize.i(0), 1);
1298 CHECK_EQ(ksize.i(3), 1);
1299 maxpool->kheight = ksize.i(1);
1300 maxpool->kwidth = ksize.i(2);
1301 const auto& padding = GetStringAttr(node, "padding");
1302 if (padding == "SAME") {
1303 maxpool->padding.type = PaddingType::kSame;
1304 } else if (padding == "VALID") {
1305 maxpool->padding.type = PaddingType::kValid;
1306 } else {
1307 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1308 }
1309 model->operators.emplace_back(maxpool);
1310 return ::tensorflow::OkStatus();
1311}
1312
1313tensorflow::Status ConvertAvgPoolOperator(
1314 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1315 const ModelFlags& model_flags, Model* model) {
1316 CHECK_EQ(node.op(), "AvgPool");
1317 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1318 const auto& input_name = node.input(0);
1319 // We only support NHWC, which is the default data_format.
1320 // So if data_format is not defined, we're all good.
1321 if (node.attr().count("data_format")) {
1322 CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC");
1323 }
1324 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
1325 auto* avgpool = new AveragePoolOperator;
1326 avgpool->inputs.push_back(input_name);
1327 avgpool->outputs.push_back(node.name());
1328 const auto& strides = GetListAttr(node, "strides");
1329 CHECK_EQ(strides.i_size(), 4);
1330 CHECK_EQ(strides.i(0), 1);
1331 CHECK_EQ(strides.i(3), 1);
1332 avgpool->stride_height = strides.i(1);
1333 avgpool->stride_width = strides.i(2);
1334 const auto& ksize = GetListAttr(node, "ksize");
1335 CHECK_EQ(ksize.i_size(), 4);
1336 CHECK_EQ(ksize.i(0), 1);
1337 CHECK_EQ(ksize.i(3), 1);
1338 avgpool->kheight = ksize.i(1);
1339 avgpool->kwidth = ksize.i(2);
1340 const auto& padding = GetStringAttr(node, "padding");
1341 if (padding == "SAME") {
1342 avgpool->padding.type = PaddingType::kSame;
1343 } else if (padding == "VALID") {
1344 avgpool->padding.type = PaddingType::kValid;
1345 } else {
1346 LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
1347 }
1348 model->operators.emplace_back(avgpool);
1349 return ::tensorflow::OkStatus();
1350}
1351
1352tensorflow::Status ConvertBatchMatMulOperator(
1353 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1354 const ModelFlags& model_flags, Model* model) {
1355 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1356
1357 auto* batch_matmul = new BatchMatMulOperator;
1358 // https://www.tensorflow.org/versions/r0.12/api_docs/python/math_ops/matrix_math_functions
1359 if (HasAttr(node, "adj_x")) {
1360 batch_matmul->adj_x = GetBoolAttr(node, "adj_x");
1361 }
1362 if (HasAttr(node, "adj_y")) {
1363 batch_matmul->adj_y = GetBoolAttr(node, "adj_y");
1364 }
1365 batch_matmul->inputs = {node.input(0), node.input(1)};
1366 batch_matmul->outputs = {node.name()};
1367
1368 // For Flex mode. Please read the comments of the function.
1369 RetainTensorFlowNodeDef(node, batch_matmul);
1370
1371 model->operators.emplace_back(batch_matmul);
1372 return ::tensorflow::OkStatus();
1373}
1374
1375tensorflow::Status ConvertMatMulOperator(
1376 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1377 const ModelFlags& model_flags, Model* model) {
1378 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1379
1380 CHECK(!HasAttr(node, "adjoint_a") ||
1381 (GetBoolAttr(node, "adjoint_a") == false));
1382 CHECK(!HasAttr(node, "adjoint_b") ||
1383 (GetBoolAttr(node, "adjoint_b") == false));
1384
1385 auto* matmul = new TensorFlowMatMulOperator;
1386 if (HasAttr(node, "transpose_a")) {
1387 matmul->transpose_a = GetBoolAttr(node, "transpose_a");
1388 }
1389 if (HasAttr(node, "transpose_b")) {
1390 matmul->transpose_b = GetBoolAttr(node, "transpose_b");
1391 }
1392
1393 matmul->inputs = {node.input(0), node.input(1)};
1394 matmul->outputs = {node.name()};
1395 model->operators.emplace_back(matmul);
1396 return ::tensorflow::OkStatus();
1397}
1398
1399tensorflow::Status ConvertConcatOperator(
1400 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1401 const ModelFlags& model_flags, Model* model) {
1402 Operator* op = nullptr;
1403 if (node.op() == "Concat") {
1404 op = new TensorFlowConcatOperator;
1405 } else if (node.op() == "ConcatV2") {
1406 op = new TensorFlowConcatV2Operator;
1407 } else {
1408 LOG(FATAL) << "Expected Concat or ConcatV2";
1409 }
1410 const int num_inputs = GetInputsCount(node, tf_import_flags);
1411 QCHECK_GE(num_inputs, 2)
1412 << node.op()
1413 << " node expects at least 2 inputs other than control dependencies: "
1414 << node.DebugString();
1415 CHECK_EQ(num_inputs, 1 + GetIntAttr(node, "N"));
1416 for (int i = 0; i < num_inputs; ++i) {
1417 op->inputs.push_back(node.input(i));
1418 }
1419 op->outputs.push_back(node.name());
1420 model->operators.emplace_back(op);
1421 return ::tensorflow::OkStatus();
1422}
1423
1424tensorflow::Status ConvertMirrorPadOperator(
1425 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1426 const ModelFlags& model_flags, Model* model) {
1427 if (node.op() != "MirrorPad") {
1428 LOG(FATAL) << "Expected MirrorPad.";
1429 }
1430 const int num_inputs = GetInputsCount(node, tf_import_flags);
1431 CHECK_EQ(num_inputs, 2);
1432 auto* op = new MirrorPadOperator;
1433 for (int i = 0; i < num_inputs; ++i) {
1434 op->inputs.push_back(node.input(i));
1435 }
1436 op->outputs.push_back(node.name());
1437 const auto mode = GetStringAttr(node, "mode");
1438 if (mode == "REFLECT") {
1439 op->mode = toco::MirrorPadMode::kReflect;
1440 } else if (mode == "SYMMETRIC") {
1441 op->mode = toco::MirrorPadMode::kSymmetric;
1442 }
1443
1444 model->operators.emplace_back(op);
1445
1446 return ::tensorflow::OkStatus();
1447}
1448
// Sentinel for the NumInputs template argument of the ConvertSimpleOperator*
// helpers: disables the input-count check.
static constexpr int kAnyNumInputs = -1;

// Selects whether an operator produced by ConvertSimpleOperatorGeneric keeps
// its TensorFlow NodeDef, making it eligible for export as a flex op.
enum FlexSupport {
  kFlexOk,
  kFlexNotOk,
};
1452
1453// This method supports simple operators without additional attributes.
1454// Converts a simple operator that takes no attributes. The list of inputs is
1455// taken from the given NodeDef, and its number must match NumInputs, unless
1456// kAnyNumInputs is passed in. If kFlexOk is passed in the resulting operator
1457// will be eligible for being exported as a flex op.
1458template <typename Op, int NumInputs, int NumOutputs, FlexSupport flex>
1459tensorflow::Status ConvertSimpleOperatorGeneric(
1460 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1461 const ModelFlags& model_flags, Model* model) {
1462 if (NumInputs != kAnyNumInputs) {
1463 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, NumInputs));
1464 }
1465 auto* op = new Op;
1466 const int num_inputs = GetInputsCount(node, tf_import_flags);
1467 for (int i = 0; i < num_inputs; ++i) {
1468 op->inputs.push_back(node.input(i));
1469 }
1470 op->outputs.push_back(node.name());
1471 if (NumOutputs > 1) {
1472 for (int i = 1; i < NumOutputs; ++i) {
1473 op->outputs.push_back(node.name() + ":" + std::to_string(i));
1474 }
1475 }
1476
1477 if (flex == kFlexOk) {
1478 RetainTensorFlowNodeDef(node, op);
1479 }
1480
1481 model->operators.emplace_back(op);
1482 return ::tensorflow::OkStatus();
1483}
1484
1485// Convert a simple operator which is not valid as a flex op.
1486template <typename Op, int NumInputs, int NumOutputs>
1487tensorflow::Status ConvertSimpleOperator(
1488 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1489 const ModelFlags& model_flags, Model* model) {
1490 return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexNotOk>(
1491 node, tf_import_flags, model_flags, model);
1492}
1493
1494// Convert a simple operator which is valid as a flex op.
1495template <typename Op, int NumInputs, int NumOutputs>
1496tensorflow::Status ConvertSimpleOperatorFlexOk(
1497 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1498 const ModelFlags& model_flags, Model* model) {
1499 return ConvertSimpleOperatorGeneric<Op, NumInputs, NumOutputs, kFlexOk>(
1500 node, tf_import_flags, model_flags, model);
1501}
1502
1503// Same as ConvertConstOperator, but revert to ConvertUnsupportedOperator if
1504// the types are not supported. Converting Const operators here avoids
1505// expensive copies of the protocol buffers downstream in the flex delegate.
1506tensorflow::Status ConditionallyConvertConstOperator(
1507 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1508 const ModelFlags& model_flags, Model* model) {
1509 // We avoid incomplete and zero shapes because the resulting arrays
1510 // are not completely compatible with Eager/TensorFlow.
1511 const auto& tensor = GetTensorAttr(node, "value");
1512 const auto& shape = tensor.tensor_shape();
1513 for (const auto& dim : shape.dim()) {
1514 if (dim.size() <= 0) {
1515 return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
1516 model);
1517 }
1518 }
1519 switch (GetDataTypeAttr(node, "dtype")) {
1520 case DT_FLOAT:
1521 case DT_INT32:
1522 case DT_QUINT8:
1523 case DT_INT64:
1524 case DT_STRING:
1525 case DT_BOOL:
1526 case DT_COMPLEX64:
1527 return ConvertConstOperator(node, tf_import_flags, model_flags, model);
1528 default:
1529 return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
1530 model);
1531 }
1532}
1533
1534tensorflow::Status ConvertStridedSliceOperator(
1535 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1536 const ModelFlags& model_flags, Model* model) {
1537 CHECK_EQ(node.op(), "StridedSlice");
1538 // TODO(soroosh): The 4th input (strides) should be e optional, to be
1539 // consistent with TF.
1540 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
1541
1542 auto* op = new StridedSliceOperator;
1543 for (const auto& input : node.input()) {
1544 op->inputs.push_back(input);
1545 }
1546 op->outputs.push_back(node.name());
1547
1548 op->begin_mask =
1549 HasAttr(node, "begin_mask") ? GetIntAttr(node, "begin_mask") : 0;
1550 op->ellipsis_mask =
1551 HasAttr(node, "ellipsis_mask") ? GetIntAttr(node, "ellipsis_mask") : 0;
1552 op->end_mask = HasAttr(node, "end_mask") ? GetIntAttr(node, "end_mask") : 0;
1553 op->new_axis_mask =
1554 HasAttr(node, "new_axis_mask") ? GetIntAttr(node, "new_axis_mask") : 0;
1555 op->shrink_axis_mask = HasAttr(node, "shrink_axis_mask")
1556 ? GetIntAttr(node, "shrink_axis_mask")
1557 : 0;
1558
1559 model->operators.emplace_back(op);
1560 return ::tensorflow::OkStatus();
1561}
1562
1563tensorflow::Status ConvertPlaceholderOperator(
1564 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1565 const ModelFlags& model_flags, Model* model) {
1566 CHECK(node.op() == "Placeholder" || node.op() == "LegacyFedInput");
1567 if (node.op() == "Placeholder") {
1568 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 0));
1569 }
1570
1571 bool inside_input_arrays = false;
1572 for (const auto& input_array : model_flags.input_arrays()) {
1573 if (node.name() == input_array.name()) {
1574 inside_input_arrays = true;
1575 break;
1576 }
1577 }
1578
1579 if (!inside_input_arrays) {
1580 model->AddInvalidInputArray(node.name());
1581 }
1582
1583 auto& array = model->GetOrCreateArray(node.name());
1584 if (node.attr().count("dtype")) {
1585 array.data_type = ConvertDataType(GetDataTypeAttr(node, "dtype"));
1586 }
1587 if (node.attr().count("shape")) {
1588 const auto& shape = GetShapeAttr(node, "shape");
1589 auto num_dims = shape.dim_size();
1590 // TODO(b/62716978): This logic needs to be revisited. During dims
1591 // refactoring it is an interim fix.
1592 if (num_dims > 0 && !HasWildcardDimension(shape)) {
1593 auto& dst_array_dims = *array.mutable_shape()->mutable_dims();
1594 dst_array_dims.resize(num_dims);
1595 for (std::size_t i = 0; i < num_dims; i++) {
1596 dst_array_dims[i] = shape.dim(i).size();
1597 }
1598 }
1599 }
1600 return ::tensorflow::OkStatus();
1601}
1602
1603tensorflow::Status ConvertNoOpOperator(
1604 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1605 const ModelFlags& model_flags, Model* model) {
1606 return ::tensorflow::OkStatus();
1607}
1608
1609tensorflow::Status ConvertCastOperator(
1610 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1611 const ModelFlags& model_flags, Model* model) {
1612 CHECK_EQ(node.op(), "Cast");
1613 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1614 const auto tf_src_dtype = GetDataTypeAttr(node, "SrcT");
1615 const auto tf_dst_dtype = GetDataTypeAttr(node, "DstT");
1616 auto* op = new CastOperator;
1617 op->src_data_type = ConvertDataType(tf_src_dtype);
1618 op->dst_data_type = ConvertDataType(tf_dst_dtype);
1619 op->inputs.push_back(node.input(0));
1620 op->outputs.push_back(node.name());
1621 model->operators.emplace_back(op);
1622 return ::tensorflow::OkStatus();
1623}
1624
1625tensorflow::Status ConvertFloorOperator(
1626 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1627 const ModelFlags& model_flags, Model* model) {
1628 CHECK_EQ(node.op(), "Floor");
1629 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1630 const auto data_type = GetDataTypeAttr(node, "T");
1631 CHECK(data_type == DT_FLOAT);
1632 auto* op = new FloorOperator;
1633 op->inputs.push_back(node.input(0));
1634 op->outputs.push_back(node.name());
1635 model->operators.emplace_back(op);
1636 return ::tensorflow::OkStatus();
1637}
1638
1639tensorflow::Status ConvertCeilOperator(
1640 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1641 const ModelFlags& model_flags, Model* model) {
1642 CHECK_EQ(node.op(), "Ceil");
1643 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1644 const auto data_type = GetDataTypeAttr(node, "T");
1645 CHECK(data_type == DT_FLOAT);
1646 auto* op = new CeilOperator;
1647 op->inputs.push_back(node.input(0));
1648 op->outputs.push_back(node.name());
1649 model->operators.emplace_back(op);
1650 return ::tensorflow::OkStatus();
1651}
1652
1653tensorflow::Status ConvertRoundOperator(
1654 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1655 const ModelFlags& model_flags, Model* model) {
1656 CHECK_EQ(node.op(), "Round");
1657 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
1658 const auto data_type = GetDataTypeAttr(node, "T");
1659 CHECK(data_type == DT_FLOAT);
1660 auto* op = new RoundOperator;
1661 op->inputs.push_back(node.input(0));
1662 op->outputs.push_back(node.name());
1663 model->operators.emplace_back(op);
1664 return ::tensorflow::OkStatus();
1665}
1666
1667tensorflow::Status ConvertGatherOperator(
1668 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1669 const ModelFlags& model_flags, Model* model) {
1670 CHECK(node.op() == "Gather" || node.op() == "GatherV2");
1671 if (node.op() == "Gather")
1672 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1673 if (node.op() == "GatherV2")
1674 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1675 const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
1676 CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
1677 auto* op = new GatherOperator;
1678 op->inputs.push_back(node.input(0));
1679 op->inputs.push_back(node.input(1));
1680 if (node.input_size() >= 3) {
1681 // GatherV2 form where we are provided an axis. It may be either a constant
1682 // or runtime defined value, so we just wire up the array and let
1683 // ResolveGatherAttributes take care of it later on.
1684 const auto axis_data_type = GetDataTypeAttr(node, "Taxis");
1685 CHECK(axis_data_type == DT_INT32 || axis_data_type == DT_INT64);
1686 op->inputs.push_back(node.input(2));
1687 } else {
1688 // Gather form that assumes axis=0.
1689 op->axis = {0};
1690 }
1691 op->outputs.push_back(node.name());
1692 model->operators.emplace_back(op);
1693 return ::tensorflow::OkStatus();
1694}
1695
1696tensorflow::Status ConvertGatherNdOperator(
1697 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1698 const ModelFlags& model_flags, Model* model) {
1699 CHECK_EQ(node.op(), "GatherNd");
1700 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1701 const auto indices_data_type = GetDataTypeAttr(node, "Tindices");
1702 CHECK(indices_data_type == DT_INT32 || indices_data_type == DT_INT64);
1703 auto* op = new GatherNdOperator;
1704 op->inputs.push_back(node.input(0));
1705 op->inputs.push_back(node.input(1));
1706 op->outputs.push_back(node.name());
1707 model->operators.emplace_back(op);
1708 return ::tensorflow::OkStatus();
1709}
1710
1711template <typename Op>
1712tensorflow::Status ConvertArgMinMaxOperator(
1713 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1714 const ModelFlags& model_flags, Model* model) {
1715 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1716 const auto axis_data_type =
1717 HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
1718 const auto output_type = HasAttr(node, "output_type")
1719 ? GetDataTypeAttr(node, "output_type")
1720 : DT_INT64;
1721 CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32);
1722 CHECK(output_type == DT_INT64 || output_type == DT_INT32);
1723 auto* op = new Op;
1724 op->output_data_type = ConvertDataType(output_type);
1725 op->inputs.push_back(node.input(0));
1726 op->inputs.push_back(node.input(1));
1727 op->outputs.push_back(node.name());
1728 model->operators.emplace_back(op);
1729 return ::tensorflow::OkStatus();
1730}
1731
1732tensorflow::Status ConvertArgMaxOperator(
1733 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1734 const ModelFlags& model_flags, Model* model) {
1735 CHECK_EQ(node.op(), "ArgMax");
1736 return ConvertArgMinMaxOperator<ArgMaxOperator>(node, tf_import_flags,
1737 model_flags, model);
1738}
1739
1740tensorflow::Status ConvertArgMinOperator(
1741 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1742 const ModelFlags& model_flags, Model* model) {
1743 CHECK_EQ(node.op(), "ArgMin");
1744 return ConvertArgMinMaxOperator<ArgMinOperator>(node, tf_import_flags,
1745 model_flags, model);
1746}
1747
1748tensorflow::Status ConvertResizeBilinearOperator(
1749 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1750 const ModelFlags& model_flags, Model* model) {
1751 CHECK_EQ(node.op(), "ResizeBilinear");
1752 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1753 auto* op = new ResizeBilinearOperator;
1754
1755 op->align_corners = false;
1756 op->half_pixel_centers = false;
1757 if (HasAttr(node, "align_corners")) {
1758 op->align_corners = GetBoolAttr(node, "align_corners");
1759 }
1760 if (HasAttr(node, "half_pixel_centers")) {
1761 op->half_pixel_centers = GetBoolAttr(node, "half_pixel_centers");
1762 }
1763
1764 op->inputs.push_back(node.input(0));
1765 op->inputs.push_back(node.input(1));
1766 op->outputs.push_back(node.name());
1767 model->operators.emplace_back(op);
1768 return ::tensorflow::OkStatus();
1769}
1770
1771tensorflow::Status ConvertResizeNearestNeighborOperator(
1772 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1773 const ModelFlags& model_flags, Model* model) {
1774 CHECK_EQ(node.op(), "ResizeNearestNeighbor");
1775 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1776 auto* op = new ResizeNearestNeighborOperator;
1777
1778 op->align_corners = false;
1779 op->half_pixel_centers = false;
1780 if (HasAttr(node, "align_corners")) {
1781 op->align_corners = GetBoolAttr(node, "align_corners");
1782 }
1783 if (HasAttr(node, "half_pixel_centers")) {
1784 op->half_pixel_centers = GetBoolAttr(node, "half_pixel_centers");
1785 }
1786
1787 op->inputs.push_back(node.input(0));
1788 op->inputs.push_back(node.input(1));
1789 op->outputs.push_back(node.name());
1790 model->operators.emplace_back(op);
1791 return ::tensorflow::OkStatus();
1792}
1793
1794tensorflow::Status ConvertBatchNormWithGlobalNormalizationOperator(
1795 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1796 const ModelFlags& model_flags, Model* model) {
1797 CHECK_EQ(node.op(), "BatchNormWithGlobalNormalization");
1798 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
1799
1800 // TODO(ahentz): to really match tensorflow we need to add variance_epsilon
1801 // to the input, before feeding it into TensorFlowRsqrtOperator.
1802 // CHECK_EQ(GetFloatAttr(node, "variance_epsilon"), 0.001f);
1803
1804 std::string multiplier = node.name() + "_mul";
1805 if (GetBoolAttr(node, "scale_after_normalization")) {
1806 // Create graph:
1807 // v -> RSQRT ->
1808 // MUL -> multiplier
1809 // gamma ----->
1810 std::string rsqrt = node.name() + "_rsqrt";
1811
1812 auto* rsqrt_op = new TensorFlowRsqrtOperator;
1813 rsqrt_op->inputs.push_back(node.input(2));
1814 rsqrt_op->outputs.push_back(rsqrt);
1815 model->operators.emplace_back(rsqrt_op);
1816
1817 auto* mul_op = new MulOperator;
1818 mul_op->inputs.push_back(rsqrt);
1819 mul_op->inputs.push_back(node.input(4));
1820 mul_op->outputs.push_back(multiplier);
1821 model->operators.emplace_back(mul_op);
1822 } else {
1823 // Create graph:
1824 // v -> RSQRT -> multiplier
1825 auto* rsqrt_op = new TensorFlowRsqrtOperator;
1826 rsqrt_op->inputs.push_back(node.input(2));
1827 rsqrt_op->outputs.push_back(multiplier);
1828 model->operators.emplace_back(rsqrt_op);
1829 }
1830
1831 auto* op = new BatchNormalizationOperator;
1832 op->global_normalization = true;
1833
1834 op->inputs.push_back(node.input(0));
1835 op->inputs.push_back(node.input(1));
1836 op->inputs.push_back(multiplier);
1837 op->inputs.push_back(node.input(3));
1838 op->outputs.push_back(node.name());
1839
1840 model->operators.emplace_back(op);
1841 return ::tensorflow::OkStatus();
1842}
1843
1844tensorflow::Status ConvertFusedBatchNormOperator(
1845 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1846 const ModelFlags& model_flags, Model* model) {
1847 CHECK((node.op() == "FusedBatchNorm") || (node.op() == "FusedBatchNormV3"));
1848 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 5));
1849
1850 // Declare shortcuts for the inputs.
1851 const std::string& gamma_input = node.input(1);
1852 const std::string& beta_input = node.input(2);
1853 const std::string& moving_mean_input = node.input(3);
1854 const std::string& moving_variance_input = node.input(4);
1855
1856 // Create an array holding the epsilon value (typically, 0.001).
1857 const std::string epsilon_array_name =
1858 CreateConstArray<ArrayDataType::kFloat>(model,
1859 node.name() + "_epsilon_array",
1860 {GetFloatAttr(node, "epsilon")});
1861
1862 // Add epsilon to the moving variance.
1863 const std::string epsilon_add_op_name = node.name() + "_epsilon";
1864 auto* epsilon_add_op = new AddOperator;
1865 epsilon_add_op->inputs.push_back(moving_variance_input);
1866 epsilon_add_op->inputs.push_back(epsilon_array_name);
1867 epsilon_add_op->outputs.push_back(epsilon_add_op_name);
1868 model->operators.emplace_back(epsilon_add_op);
1869
1870 // Take the inverse square root of the (variance + epsilon).
1871 const std::string rsqrt_op_name = node.name() + "_rsqrt";
1872 auto* rsqrt_op = new TensorFlowRsqrtOperator;
1873 rsqrt_op->inputs.push_back(epsilon_add_op_name);
1874 rsqrt_op->outputs.push_back(rsqrt_op_name);
1875 model->operators.emplace_back(rsqrt_op);
1876
1877 // Multiply the result by gamma.
1878 const std::string multiplier = node.name() + "_mul";
1879 auto* mul_op = new MulOperator;
1880 mul_op->inputs.push_back(rsqrt_op_name);
1881 mul_op->inputs.push_back(gamma_input);
1882 mul_op->outputs.push_back(multiplier);
1883 model->operators.emplace_back(mul_op);
1884
1885 // Now we have all required inputs for the BatchNormalizationOperator.
1886 auto* op = new BatchNormalizationOperator;
1887 op->global_normalization = true;
1888
1889 op->inputs.push_back(node.input(0));
1890 op->inputs.push_back(moving_mean_input);
1891 op->inputs.push_back(multiplier);
1892 op->inputs.push_back(beta_input);
1893 op->outputs.push_back(node.name());
1894
1895 model->operators.emplace_back(op);
1896 return ::tensorflow::OkStatus();
1897}
1898
1899tensorflow::Status ConvertSpaceToBatchNDOperator(
1900 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1901 const ModelFlags& model_flags, Model* model) {
1902 CHECK_EQ(node.op(), "SpaceToBatchND");
1903 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1904 CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
1905 CHECK_EQ(GetDataTypeAttr(node, "Tpaddings"), DT_INT32);
1906 auto* op = new SpaceToBatchNDOperator;
1907 op->inputs.push_back(node.input(0));
1908 op->inputs.push_back(node.input(1));
1909 op->inputs.push_back(node.input(2));
1910 op->outputs.push_back(node.name());
1911 model->operators.emplace_back(op);
1912 return ::tensorflow::OkStatus();
1913}
1914
1915tensorflow::Status ConvertBatchToSpaceNDOperator(
1916 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1917 const ModelFlags& model_flags, Model* model) {
1918 CHECK_EQ(node.op(), "BatchToSpaceND");
1919 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1920 CHECK_EQ(GetDataTypeAttr(node, "Tblock_shape"), DT_INT32);
1921 CHECK_EQ(GetDataTypeAttr(node, "Tcrops"), DT_INT32);
1922 auto* op = new BatchToSpaceNDOperator;
1923 op->inputs.push_back(node.input(0));
1924 op->inputs.push_back(node.input(1));
1925 op->inputs.push_back(node.input(2));
1926 op->outputs.push_back(node.name());
1927 model->operators.emplace_back(op);
1928 return ::tensorflow::OkStatus();
1929}
1930
1931template <typename T>
1932tensorflow::Status ConvertReduceOperator(
1933 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1934 const ModelFlags& model_flags, Model* model) {
1935 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
1936 auto* op = new T;
1937 op->inputs.push_back(node.input(0));
1938 op->inputs.push_back(node.input(1));
1939 op->outputs.push_back(node.name());
1940 model->operators.emplace_back(op);
1941 if (HasAttr(node, "keepdims")) {
1942 op->keep_dims = GetBoolAttr(node, "keepdims");
1943 } else if (HasAttr(node, "keep_dims")) {
1944 op->keep_dims = GetBoolAttr(node, "keep_dims");
1945 }
1946 return ::tensorflow::OkStatus();
1947}
1948
1949// TODO(b/139320642): Add test when fused op is supported.
1950tensorflow::Status ConvertSvdfOperator(
1951 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1952 const ModelFlags& model_flags, Model* model) {
1953 CHECK_EQ(node.op(), "Svdf");
1954 const int input_size = GetInputsCount(node, tf_import_flags);
1955 QCHECK(input_size == 4 || input_size == 5)
1956 << "Svdf node expects 3 or 4 inputs other than control dependencies: "
1957 << node.DebugString();
1958 bool has_bias = (input_size == 5);
1959 auto* op = new SvdfOperator;
1960 int index = 0;
1961 op->inputs.push_back(node.input(index++));
1962 op->inputs.push_back(node.input(index++));
1963 op->inputs.push_back(node.input(index++));
1964 if (has_bias) {
1965 op->inputs.push_back(node.input(index++));
1966 }
1967 op->inputs.push_back(node.input(index));
1968 op->outputs.push_back(node.name());
1969 if (node.attr().at("ActivationFunction").s() == "Relu") {
1970 op->fused_activation_function = FusedActivationFunctionType::kRelu;
1971 } else {
1972 op->fused_activation_function = FusedActivationFunctionType::kNone;
1973 }
1974 op->rank = node.attr().at("Rank").i();
1975 model->operators.emplace_back(op);
1976 return ::tensorflow::OkStatus();
1977}
1978
1979// This is just bare bones support to get the shapes to propagate.
1980tensorflow::Status ConvertTransposeConvOperator(
1981 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
1982 const ModelFlags& model_flags, Model* model) {
1983 CHECK_EQ(node.op(), "Conv2DBackpropInput");
1984 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
1985 auto* op = new TransposeConvOperator;
1986 op->inputs.push_back(node.input(0));
1987 op->inputs.push_back(node.input(1));
1988 op->inputs.push_back(node.input(2));
1989 op->outputs.push_back(node.name());
1990 const auto& strides = GetListAttr(node, "strides");
1991 op->stride_height = strides.i(1);
1992 op->stride_width = strides.i(2);
1993 CHECK_EQ(strides.i_size(), 4)
1994 << "Can only import TransposeConv ops with 4D strides. TensorFlow op \""
1995 << node.name() << "\" has " << strides.i_size() << "D strides.";
1996 CHECK((strides.i(0) == 1) && (strides.i(3) == 1))
1997 << "Can only import TransposeConv ops with striding along the height "
1998 "(1st) or width (2nd) axis. TensorFlow op \""
1999 << node.name() << "\" had strides:[ " << strides.i(0) << ", "
2000 << strides.i(1) << ", " << strides.i(2) << ", " << strides.i(3) << "].";
2001 op->stride_height = strides.i(1);
2002 op->stride_width = strides.i(2);
2003 if (HasAttr(node, "dilations")) {
2004 const auto& dilations = GetListAttr(node, "dilations");
2005 CHECK_EQ(dilations.i_size(), 4)
2006 << "Dilation unsupported in TransposeConv. TensorFlow op \""
2007 << node.name() << "\" had dilations";
2008 CHECK((dilations.i(0) == 1) && (dilations.i(1) == 1) &&
2009 (dilations.i(2) == 1) && (dilations.i(3) == 1))
2010 << "Dilation unsupported in TransposeConv. TensorFlow op \""
2011 << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
2012 << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
2013 << "].";
2014 }
2015
2016 const std::string& weights_name = node.input(TransposeConvOperator::WEIGHTS);
2017 const std::string& transposed_weights_name = weights_name + "_transposed";
2018 // Check if a TransposeOperator was already created for these weights
2019 // (can happen when multiple layers share the same weights).
2020 const Operator* existing_transpose =
2021 GetOpWithOutput(*model, transposed_weights_name);
2022 if (existing_transpose) {
2023 CHECK(existing_transpose->type == OperatorType::kTranspose);
2024 } else {
2025 // Transpose weights from HWOI order to OHWI order, which is more efficient
2026 // for computation. (Note that TensorFlow considers the order as HWIO
2027 // because they consider this a backward conv, inverting the sense of
2028 // input/output.)
2029 TransposeOperator* transpose = new TransposeOperator;
2030 std::string perm_array = CreateConstArray<ArrayDataType::kInt32>(
2031 model, node.name() + "_transpose_perm", {2, 0, 1, 3});
2032 transpose->inputs = {weights_name, perm_array};
2033 transpose->outputs = {transposed_weights_name};
2034 model->operators.emplace_back(transpose);
2035 }
2036 op->inputs[1] = transposed_weights_name;
2037
2038 auto const& padding = GetStringAttr(node, "padding");
2039 if (padding == "SAME") {
2040 op->padding.type = PaddingType::kSame;
2041 } else if (padding == "VALID") {
2042 op->padding.type = PaddingType::kValid;
2043 } else {
2044 LOG(FATAL) << "Only SAME and VALID padding supported on "
2045 "Conv2DBackpropInput nodes.";
2046 }
2047 model->operators.emplace_back(op);
2048 return ::tensorflow::OkStatus();
2049}
2050
2051tensorflow::Status ConvertRangeOperator(
2052 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2053 const ModelFlags& model_flags, Model* model) {
2054 CHECK_EQ(node.op(), "Range");
2055 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 3));
2056 auto* op = new RangeOperator;
2057 if (HasAttr(node, "Tidx")) {
2058 const auto dtype = toco::GetDataTypeAttr(node, "Tidx");
2059 CHECK(dtype == DT_UINT8 || dtype == DT_INT32 || dtype == DT_INT64 ||
2060 dtype == DT_FLOAT);
2061 op->dtype = ConvertDataType(dtype);
2062 }
2063 op->inputs.push_back(node.input(0));
2064 op->inputs.push_back(node.input(1));
2065 op->inputs.push_back(node.input(2));
2066 op->outputs.push_back(node.name());
2067
2068 model->operators.emplace_back(op);
2069 return ::tensorflow::OkStatus();
2070}
2071
2072// Note that it's easy to confuse/conflate "Stack" and "Pack" operators, but
2073// they aren't the same thing. tf.stack results in a "Pack" operator. "Stack"
2074// operators also exist, but involve manipulating the TF runtime stack, and are
2075// not directly related to tf.stack() usage.
2076tensorflow::Status ConvertPackOperator(
2077 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2078 const ModelFlags& model_flags, Model* model) {
2079 CHECK_EQ(node.op(), "Pack");
2080 auto op = std::make_unique<PackOperator>();
2081 const int num_inputs = GetInputsCount(node, tf_import_flags);
2082 QCHECK_GE(num_inputs, 1)
2083 << node.op()
2084 << " node expects at least 1 input other than control dependencies: "
2085 << node.DebugString();
2086 CHECK_EQ(num_inputs, GetIntAttr(node, "N"));
2087 for (int i = 0; i < num_inputs; ++i) {
2088 op->inputs.push_back(node.input(i));
2089 }
2090 op->values_count = HasAttr(node, "N") ? GetIntAttr(node, "N") : num_inputs;
2091 op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
2092 op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
2093 op->outputs.push_back(node.name());
2094 model->operators.emplace_back(std::move(op));
2095 return ::tensorflow::OkStatus();
2096}
2097
2098tensorflow::Status ConvertUnpackOperator(
2099 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2100 const ModelFlags& model_flags, Model* model) {
2101 CHECK_EQ(node.op(), "Unpack");
2102 auto op = std::make_unique<UnpackOperator>();
2103 const int num_inputs = GetInputsCount(node, tf_import_flags);
2104 QCHECK_EQ(num_inputs, 1);
2105 op->inputs.push_back(node.input(0));
2106 op->num = GetIntAttr(node, "num");
2107 op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
2108 op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
2109
2110 op->outputs.push_back(node.name()); // Implicit :0.
2111 for (int i = 1; i < op->num; ++i) {
2112 op->outputs.push_back(node.name() + ":" + std::to_string(i));
2113 }
2114 model->operators.emplace_back(std::move(op));
2115 return ::tensorflow::OkStatus();
2116}
2117
2118// Some TensorFlow ops only occur in graph cycles, representing
2119// control flow. We do not currently support control flow, so we wouldn't
2120// be able to fully support such graphs, including performing inference,
2121// anyway. However, rather than erroring out early on graphs being cyclic,
2122// it helps to at least support these just enough to allow getting a
2123// graph visualization. This is not trivial, as we require graphs to be
2124// acyclic aside from RNN back-edges. The solution is to special-case
2125// such ops as RNN back-edges, which is technically incorrect (does not
2126// allow representing the op's semantics) but good enough to get a
2127// graph visualization.
2128tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
2129 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2130 const ModelFlags& model_flags, Model* model) {
2131 // At the moment, the only type of operator special-cased in this way is
2132 // NextIteration, occurring only in control-flow cycles.
2133 CHECK_EQ(node.op(), "NextIteration");
2134 CHECK_EQ(node.input_size(), 1);
2135 auto* rnn_state = model->flags.add_rnn_states();
2136 // This RNN state is not explicitly created by the user, so it's
2137 // OK for some later graph transformation to discard it.
2138 rnn_state->set_discardable(true);
2139 rnn_state->set_state_array(node.name());
2140 rnn_state->set_back_edge_source_array(node.input(0));
2141 // TODO(tianjuny): Temporary set the size to 1 to avoid transient array
2142 // allocation crash. The real value should depend on the hidden_size of RNN.
2143 rnn_state->set_size(1);
2144 return ::tensorflow::OkStatus();
2145}
2146
2147tensorflow::Status ConvertShapeOperator(
2148 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2149 const ModelFlags& model_flags, Model* model) {
2150 CHECK_EQ(node.op(), "Shape");
2151 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
2152 const auto out_type =
2153 HasAttr(node, "out_type") ? GetDataTypeAttr(node, "out_type") : DT_INT32;
2154 CHECK(out_type == DT_INT64 || out_type == DT_INT32);
2155 auto op = std::make_unique<TensorFlowShapeOperator>();
2156 op->output_data_type = ConvertDataType(out_type);
2157 op->inputs.push_back(node.input(0));
2158 op->outputs.push_back(node.name());
2159 model->operators.push_back(std::move(op));
2160 return ::tensorflow::OkStatus();
2161}
2162
2163tensorflow::Status ConvertReverseSequenceOperator(
2164 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2165 const ModelFlags& model_flags, Model* model) {
2166 CHECK_EQ(node.op(), "ReverseSequence");
2167 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2168 auto op = std::make_unique<ReverseSequenceOperator>();
2169 if (HasAttr(node, "seq_dim")) {
2170 op->seq_dim = GetIntAttr(node, "seq_dim");
2171 }
2172 // In tf.reverse_sequence, batch_dim defaults to 0.
2173 op->batch_dim =
2174 HasAttr(node, "batch_dim") ? GetIntAttr(node, "batch_dim") : 0;
2175 const int num_inputs = GetInputsCount(node, tf_import_flags);
2176 for (int i = 0; i < num_inputs; ++i) {
2177 op->inputs.push_back(node.input(i));
2178 }
2179 op->outputs.push_back(node.name());
2180 model->operators.push_back(std::move(op));
2181 return ::tensorflow::OkStatus();
2182}
2183
2184void StripCaretFromArrayNames(Model* model) {
2185 for (auto& op : model->operators) {
2186 for (auto& input : op->inputs) {
2187 input = std::string(absl::StripPrefix(input, "^"));
2188 }
2189 for (auto& output : op->outputs) {
2190 output = std::string(absl::StripPrefix(output, "^"));
2191 }
2192 }
2193 for (auto& array : model->GetArrayMap()) {
2194 if (absl::StartsWith(array.first, "^")) {
2195 LOG(FATAL) << "What?";
2196 }
2197 }
2198}
2199
2200void StripZeroOutputIndexFromInputs(NodeDef* node) {
2201 for (auto& input : *node->mutable_input()) {
2202 input = std::string(absl::StripSuffix(input, ":0"));
2203 }
2204}
2205
2206// In TensorFlow GraphDef, when a node has multiple outputs, they are named
2207// name:0, name:1, ...
2208// where 'name' is the node's name(). Just 'name' is an equivalent shorthand
2209// form for name:0.
2210// A TensorFlow GraphDef does not explicitly list all the outputs of each node
2211// (unlike inputs), it being implied by the node's name and operator type
2212// (the latter implies the number of outputs).
2213// This makes it non-trivial for us to reconstruct the list of all arrays
2214// present in the graph and, for each operator, the list of its outputs.
2215// We do that by taking advantage of the fact that
2216// at least each node lists explicitly its inputs, so after we've loaded
2217// all nodes, we can use that information.
2218void AddExtraOutputs(Model* model) {
2219 // Construct the list of all arrays consumed by anything in the graph.
2220 std::vector<std::string> consumed_arrays;
2221 // Add arrays consumed by an op.
2222 for (const auto& consumer_op : model->operators) {
2223 for (const std::string& input : consumer_op->inputs) {
2224 consumed_arrays.push_back(input);
2225 }
2226 }
2227 // Add global outputs of the model.
2228 for (const std::string& output_array : model->flags.output_arrays()) {
2229 consumed_arrays.push_back(output_array);
2230 }
2231 // Add arrays consumed by a RNN back-edge.
2232 for (const auto& rnn_state : model->flags.rnn_states()) {
2233 consumed_arrays.push_back(rnn_state.back_edge_source_array());
2234 }
2235 // Now add operator outputs so that all arrays that are consumed,
2236 // are produced.
2237 for (const std::string& consumed_array : consumed_arrays) {
2238 // Test if consumed_array is already the output of some op.
2239 // This has occurred in a model where separate nodes had names of the form
2240 // foo:$i with the same base name foo.
2241 if (GetOpWithOutput(*model, consumed_array)) {
2242 continue;
2243 }
2244 // Split the consumed array name into the form name:output_index.
2245 const std::vector<std::string>& split = absl::StrSplit(consumed_array, ':');
2246 // If not of the form name:output_index, then this is not an additional
2247 // output of a node with multiple outputs, so nothing to do here.
2248 if (split.size() != 2) {
2249 continue;
2250 }
2251 int output_index = 0;
2252 if (!absl::SimpleAtoi(split[1], &output_index)) {
2253 continue;
2254 }
2255 // Each op is initially recorded as producing at least the array that
2256 // has its name. We use that to identify the producer node.
2257 auto* producer_op = GetOpWithOutput(*model, split[0]);
2258 if (!producer_op) {
2259 continue;
2260 }
2261 // Add extra outputs to that producer node, all the way to the
2262 // output_index.
2263 while (producer_op->outputs.size() <= output_index) {
2264 using toco::port::StringF;
2265 producer_op->outputs.push_back(
2266 StringF("%s:%d", split[0], producer_op->outputs.size()));
2267 }
2268 }
2269}
2270
2271bool InlineAllFunctions(GraphDef* graphdef) {
2272 if (graphdef->library().function().empty()) {
2273 VLOG(kLogLevelModelUnchanged) << "No functions to inline.";
2274 return false;
2275 }
2276
2277 // Override "_noinline" attribute on all functions
2278 GraphDef graphdef_copy(*graphdef);
2279 for (auto& function :
2280 (*graphdef_copy.mutable_library()->mutable_function())) {
2281 auto* attributes = function.mutable_attr();
2282 if (attributes->count(tensorflow::kNoInlineAttr) != 0) {
2283 (*attributes)[tensorflow::kNoInlineAttr].set_b(false);
2284 }
2285 }
2286
2287 // Construct minimum resources needed to use ExpandInlineFunctions().
2288 tensorflow::SessionOptions options;
2289 auto* device_count = options.config.mutable_device_count();
2290 device_count->insert({"CPU", 1});
2291 std::vector<std::unique_ptr<tensorflow::Device>> devices;
2292 TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
2293 options, "/job:localhost/replica:0/task:0", &devices));
2294
2295 tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
2296 graphdef_copy.library());
2297 tensorflow::StaticDeviceMgr device_mgr(std::move(devices));
2298 tensorflow::ProcessFunctionLibraryRuntime pflr(
2299 &device_mgr, tensorflow::Env::Default(), &options.config,
2300 TF_GRAPH_DEF_VERSION, &fld,
2301 options.config.graph_options().optimizer_options(), nullptr);
2302 tensorflow::FunctionLibraryRuntime* flr;
2303 flr = pflr.GetFLR("/job:localhost/replica:0/task:0/cpu:0");
2304
2305 tensorflow::Graph graph(fld);
2306 tensorflow::ImportGraphDefOptions gc_opts;
2307 gc_opts.validate_shape = false;
2308 const auto& tf_convert_status = tensorflow::ImportGraphDef(
2309 gc_opts, graphdef_copy, &graph, nullptr, nullptr);
2310 if (!tf_convert_status.ok()) {
2311 LOG(ERROR) << "tensorflow::ImportGraphDef failed with status: "
2312 << tf_convert_status.ToString();
2313 return false;
2314 }
2315
2316 // Iterate over the graph until there are no more nodes to be inlined.
2317 bool graph_modified = false;
2318 while (tensorflow::ExpandInlineFunctions(flr, &graph)) {
2319 graph_modified = true;
2320 }
2321
2322 // Output inlined graph
2323 if (graph_modified) {
2324 LOG(INFO) << "Found and inlined TensorFlow functions.";
2325 graph.ToGraphDef(graphdef);
2326 }
2327 return graph_modified;
2328}
2329
2330tensorflow::Status ConvertTopKV2Operator(
2331 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2332 const ModelFlags& model_flags, Model* model) {
2333 CHECK((node.op() == "TopK") || (node.op() == "TopKV2"));
2334 auto op = std::make_unique<TopKV2Operator>();
2335 op->inputs.push_back(node.input(0));
2336 // K can be encoded as attr (TopK) convert it to a const.
2337 if (HasAttr(node, "k")) {
2338 std::string k_array = CreateConstArray<ArrayDataType::kInt32>(
2339 model, node.name() + "k", {static_cast<int32>(GetIntAttr(node, "k"))});
2340 op->inputs.push_back(k_array);
2341 } else {
2342 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2343 op->inputs.push_back(node.input(1));
2344 }
2345 // The op has two outputs.
2346 op->outputs.push_back(node.name());
2347 op->outputs.push_back(node.name() + ":1");
2348 model->operators.emplace_back(op.release());
2349 return ::tensorflow::OkStatus();
2350}
2351
2352tensorflow::Status ConvertDynamicPartitionOperator(
2353 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2354 const ModelFlags& model_flags, Model* model) {
2355 auto op = std::make_unique<DynamicPartitionOperator>();
2356 CHECK(HasAttr(node, "num_partitions"));
2357 op->num_partitions = GetIntAttr(node, "num_partitions");
2358 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2359 op->inputs.push_back(node.input(0));
2360 op->inputs.push_back(node.input(1));
2361 CHECK_GT(op->num_partitions, 1);
2362 op->outputs.push_back(node.name()); // Implicit :0.
2363 for (int i = 1; i < op->num_partitions; ++i) {
2364 op->outputs.push_back(node.name() + ":" + std::to_string(i));
2365 }
2366 model->operators.emplace_back(op.release());
2367 return ::tensorflow::OkStatus();
2368}
2369
2370tensorflow::Status ConvertDynamicStitchOperator(
2371 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2372 const ModelFlags& model_flags, Model* model) {
2373 // The parallel and non-parallel variants are the same besides whether they
2374 // have a parallel loop; there are no behavioral differences.
2375 CHECK(node.op() == "DynamicStitch" || node.op() == "ParallelDynamicStitch");
2376 auto op = std::make_unique<DynamicStitchOperator>();
2377 CHECK(HasAttr(node, "N"));
2378 op->num_partitions = GetIntAttr(node, "N");
2379 // Expect all ID partitions + all value partitions.
2380 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, op->num_partitions * 2));
2381 for (int i = 0; i < op->num_partitions * 2; ++i) {
2382 op->inputs.push_back(node.input(i));
2383 }
2384 op->outputs.push_back(node.name());
2385 model->operators.emplace_back(op.release());
2386 return ::tensorflow::OkStatus();
2387}
2388
2389tensorflow::Status ConvertSparseToDenseOperator(
2390 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2391 const ModelFlags& model_flags, Model* model) {
2392 CHECK_EQ(node.op(), "SparseToDense");
2393 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
2394
2395 auto* op = new SparseToDenseOperator;
2396 for (const std::string& input : node.input()) {
2397 op->inputs.push_back(input);
2398 }
2399 op->outputs.push_back(node.name());
2400
2401 op->validate_indices = HasAttr(node, "validate_indices")
2402 ? GetBoolAttr(node, "validate_indices")
2403 : true;
2404 model->operators.emplace_back(op);
2405 return ::tensorflow::OkStatus();
2406}
2407
2408tensorflow::Status ConvertOneHotOperator(
2409 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2410 const ModelFlags& model_flags, Model* model) {
2411 CHECK_EQ(node.op(), "OneHot");
2412 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
2413
2414 const auto dtype = GetDataTypeAttr(node, "T");
2415 // TODO(b/111744875): Support DT_UINT8 and quantization.
2416 CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT ||
2417 dtype == DT_BOOL);
2418
2419 auto op = std::make_unique<OneHotOperator>();
2420 op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1;
2421 for (const std::string& input : node.input()) {
2422 op->inputs.push_back(input);
2423 }
2424 op->outputs.push_back(node.name());
2425 model->operators.emplace_back(op.release());
2426 return ::tensorflow::OkStatus();
2427}
2428
2429tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
2430 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2431 const ModelFlags& model_flags, Model* model) {
2432 CHECK_EQ(node.op(), "CTCBeamSearchDecoder");
2433 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
2434
2435 auto* op = new CTCBeamSearchDecoderOperator;
2436 for (const std::string& input : node.input()) {
2437 op->inputs.push_back(input);
2438 }
2439
2440 op->beam_width =
2441 HasAttr(node, "beam_width") ? GetIntAttr(node, "beam_width") : 1;
2442 op->top_paths =
2443 HasAttr(node, "top_paths") ? GetIntAttr(node, "top_paths") : 1;
2444 op->merge_repeated = HasAttr(node, "merge_repeated")
2445 ? GetBoolAttr(node, "merge_repeated")
2446 : true;
2447
2448 // There are top_paths + 1 outputs.
2449 op->outputs.push_back(node.name()); // Implicit :0.
2450 for (int i = 0; i < op->top_paths; ++i) {
2451 op->outputs.push_back(node.name() + ":" + std::to_string(i + 1));
2452 }
2453 model->operators.emplace_back(op);
2454 return ::tensorflow::OkStatus();
2455}
2456
2457// This isn't a TensorFlow builtin op. Currently this node can only be generated
2458// with TfLite OpHint API.
2459tensorflow::Status ConvertUnidirectionalSequenceLstm(
2460 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2461 const ModelFlags& model_flags, Model* model) {
2462 DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm");
2463
2464 const auto& indices = GetListAttr(node, "_tflite_input_indices");
2465
2466 auto* op = new UnidirectionalSequenceLstmOperator();
2467
2468 // The input size needs to be the same as the TfLite UniDirectionalSequence
2469 // Lstm implementation.
2470 const int kInputsSize = 20;
2471
2472 op->inputs.resize(kInputsSize);
2473
2474 if (indices.i_size() != node.input().size()) {
2475 // New version, the optional inputs are filled with constant nodes.
2476 int count = 0;
2477 for (int idx = 0; idx < kInputsSize; ++idx) {
2478 if (count < indices.i_size() && indices.i(count) == idx) {
2479 // Specified input.
2480 op->inputs[idx] = node.input(idx);
2481 count++;
2482 } else {
2483 // Optional input.
2484 std::string optional_name = node.name() + "_" + std::to_string(idx);
2485 model->CreateOptionalArray(optional_name);
2486 op->inputs[idx] = optional_name;
2487 }
2488 }
2489 } else { // Legacy version.
2490 std::vector<bool> done(kInputsSize);
2491 int idx = 0;
2492 for (const std::string& input : node.input()) {
2493 int real_index = indices.i(idx);
2494 op->inputs[real_index] = (input);
2495 done[real_index] = true;
2496 idx++;
2497 }
2498
2499 for (int idx = 0; idx < done.size(); idx++) {
2500 if (!done[idx]) {
2501 std::string optional_name = node.name() + "_" + std::to_string(idx);
2502 model->CreateOptionalArray(optional_name);
2503 op->inputs[idx] = optional_name;
2504 }
2505 }
2506 }
2507
2508 // There're three outputs, only the last one is required.
2509 op->outputs.push_back(node.name() + ":2");
2510 model->operators.emplace_back(op);
2511
2512 return ::tensorflow::OkStatus();
2513}
2514
2515tensorflow::Status ConvertLeakyReluOperator(
2516 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2517 const ModelFlags& model_flags, Model* model) {
2518 CHECK_EQ(node.op(), "LeakyRelu");
2519 TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
2520 CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT);
2521 const auto& input_name = node.input(0);
2522 auto* op = new LeakyReluOperator;
2523 op->inputs.push_back(input_name);
2524 op->outputs.push_back(node.name());
2525 op->alpha = GetFloatAttr(node, "alpha");
2526 model->operators.emplace_back(op);
2527 return ::tensorflow::OkStatus();
2528}
2529
2530tensorflow::Status ConvertUnidirectionalSequenceRnn(
2531 const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
2532 const ModelFlags& model_flags, Model* model) {
2533 DCHECK_EQ(node.op(), "UnidirectionalSequenceRnn");
2534
2535 const auto& indices = GetListAttr(node, "_tflite_input_indices");
2536 if (indices.i_size() != node.input().size()) {
2537 return tensorflow::errors::InvalidArgument("Input size does not match.");
2538 }
2539
2540 auto* op = new UnidirectionalSequenceRnnOperator();
2541 for (const std::string& input : node.input()) {
2542 op->inputs.push_back(input);
2543 }
2544 // Only use the last one as input.
2545 op->outputs.push_back(node.name() + ":1");
2546 model->operators.emplace_back(op);
2547
2548 return ::tensorflow::OkStatus();
2549}
2550
2551} // namespace
2552
2553namespace internal {
2554
// Signature shared by the per-op import functions stored in a converter map:
// each one receives the node being imported, the import and model flags, and
// the model to add to, and reports the outcome as a Status.
using ConverterType = tensorflow::Status (*)(
    const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
    const ModelFlags& model_flags, Model* model);
// Lookup table from an op name (as returned by NodeDef::op()) to the function
// that imports nodes of that op.
using ConverterMapType = std::unordered_map<std::string, ConverterType>;
2559
2560ConverterMapType GetTensorFlowNodeConverterMapForFlex() {
2561 return std::unordered_map<std::string, ConverterType>({
2562 // We need to let TOCO convert Placeholder information into
2563 // array data, so that the data types are correct.
2564 {"LegacyFedInput", ConvertPlaceholderOperator},
2565 {"Placeholder", ConvertPlaceholderOperator},
2566 {"Const", ConditionallyConvertConstOperator},
2567 });
2568}
2569
// Returns the full converter map: one entry per TensorFlow op name that has a
// dedicated import function. Several op names share a converter (for example
// "Add"/"AddV2" and "Select"/"SelectV2"). Op names without an entry here are
// routed to ConvertUnsupportedOperator by ImportTensorFlowNode.
//
// NOTE(review): the integer template arguments of ConvertSimpleOperator are
// presumably the expected input and output counts, with kAnyNumInputs
// disabling the input check -- confirm against ConvertSimpleOperator.
ConverterMapType GetTensorFlowNodeConverterMap() {
  return std::unordered_map<std::string, ConverterType>({
      {"Abs", ConvertSimpleOperator<AbsOperator, kAnyNumInputs, 1>},
      {"Add", ConvertSimpleOperator<AddOperator, 2, 1>},
      {"AddV2", ConvertSimpleOperator<AddOperator, 2, 1>},
      {"AddN", ConvertSimpleOperator<AddNOperator, kAnyNumInputs, 1>},
      {"All", ConvertSimpleOperator<TensorFlowAllOperator, kAnyNumInputs, 1>},
      {"Any", ConvertReduceOperator<TensorFlowAnyOperator>},
      {"ArgMax", ConvertArgMaxOperator},
      {"ArgMin", ConvertArgMinOperator},
      {"Assert",
       ConvertSimpleOperator<TensorFlowAssertOperator, kAnyNumInputs, 1>},
      {"AvgPool", ConvertAvgPoolOperator},
      {"BatchMatMul", ConvertBatchMatMulOperator},
      {"BatchMatMulV2", ConvertBatchMatMulOperator},
      {"BatchNormWithGlobalNormalization",
       ConvertBatchNormWithGlobalNormalizationOperator},
      {"BatchToSpaceND", ConvertBatchToSpaceNDOperator},
      {"BiasAdd", ConvertBiasAddOperator},
      {"Cast", ConvertCastOperator},
      {"Ceil", ConvertCeilOperator},
      {"CheckNumerics", ConvertIdentityOperator},
      {"Concat", ConvertConcatOperator},
      {"ConcatV2", ConvertConcatOperator},
      {"Const", ConvertConstOperator},
      {"Conv2D", ConvertConvOperator},
      {"Conv2DBackpropInput", ConvertTransposeConvOperator},
      {"Cos", ConvertSimpleOperator<CosOperator, 1, 1>},
      {"CTCBeamSearchDecoder", ConvertCTCBeamSearchDecoderOperator},
      {"DepthToSpace", ConvertDepthToSpaceOperator},
      {"DepthwiseConv2dNative", ConvertDepthwiseConvOperator},
      {"Div", ConvertSimpleOperator<DivOperator, 2, 1>},
      {"DynamicPartition", ConvertDynamicPartitionOperator},
      {"DynamicStitch", ConvertDynamicStitchOperator},
      {"Elu", ConvertSimpleOperator<EluOperator, 1, 1>},
      {"EnsureShape", ConvertIdentityOperator},
      {"Equal", ConvertSimpleOperator<TensorFlowEqualOperator, 2, 1>},
      {"Exp", ConvertSimpleOperator<ExpOperator, 1, 1>},
      {"ExpandDims", ConvertSimpleOperator<ExpandDimsOperator, 2, 1>},
      {"FakeQuantWithMinMaxArgs", ConvertFakeQuantWithMinMaxArgs},
      {"FakeQuantWithMinMaxVars", ConvertFakeQuantWithMinMaxVars},
      {"Fill", ConvertSimpleOperator<FillOperator, 2, 1>},
      {"Floor", ConvertFloorOperator},
      {"FloorDiv", ConvertSimpleOperator<FloorDivOperator, 2, 1>},
      {"FloorMod", ConvertSimpleOperator<FloorModOperator, 2, 1>},
      {"FusedBatchNorm", ConvertFusedBatchNormOperator},
      {"FusedBatchNormV3", ConvertFusedBatchNormOperator},
      {"Gather", ConvertGatherOperator},
      {"GatherV2", ConvertGatherOperator},
      {"GatherNd", ConvertGatherNdOperator},
      {"Greater", ConvertSimpleOperator<TensorFlowGreaterOperator, 2, 1>},
      {"GreaterEqual",
       ConvertSimpleOperator<TensorFlowGreaterEqualOperator, 2, 1>},
      {"Identity", ConvertIdentityOperator},
      {"IdentityN", ConvertIdentityNOperator},
      {"LRN", ConvertLRNOperator},
      {"LeakyRelu", ConvertLeakyReluOperator},
      {"LegacyFedInput", ConvertPlaceholderOperator},
      {"Less", ConvertSimpleOperator<TensorFlowLessOperator, 2, 1>},
      {"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2, 1>},
      {"Log", ConvertSimpleOperator<LogOperator, 1, 1>},
      {"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2, 1>},
      {"LogicalOr", ConvertSimpleOperator<LogicalOrOperator, 2, 1>},
      {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1, 1>},
      {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1, 1>},
      {"MatMul", ConvertMatMulOperator},
      {"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
      {"MatrixDiagV2", ConvertSimpleOperator<MatrixDiagV2Operator, 5, 1>},
      // `MatrixDiagV3` has an `align` attribute. However, Toco only converts
      // `MatrixDiagV3` to `MatrixDiag` with default `k, num_rows, num_cols,
      // padding_value` inputs. In this case, `align` can be ignored.
      {"MatrixDiagV3", ConvertSimpleOperator<MatrixDiagV3Operator, 5, 1>},
      {"MatrixSetDiag", ConvertSimpleOperator<MatrixSetDiagOperator, 2, 1>},
      {"MatrixSetDiagV2", ConvertSimpleOperator<MatrixSetDiagV2Operator, 3, 1>},
      // `MatrixSetDiagV3` has an `align` attribute. However, Toco only converts
      // `MatrixSetDiagV3` to `MatrixSetDiag` with default `k` inputs. In this
      // case, `align` can be ignored.
      {"MatrixSetDiagV3", ConvertSimpleOperator<MatrixSetDiagV3Operator, 3, 1>},
      {"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
      {"MaxPool", ConvertMaxPoolOperator},
      {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
      {"Mean", ConvertReduceOperator<MeanOperator>},
      {"Merge",
       ConvertSimpleOperator<TensorFlowMergeOperator, kAnyNumInputs, 1>},
      {"Min", ConvertReduceOperator<TensorFlowMinOperator>},
      {"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2, 1>},
      {"Mul", ConvertSimpleOperator<MulOperator, 2, 1>},
      {"Neg", ConvertSimpleOperator<NegOperator, 1, 1>},
      {"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
      {"NoOp", ConvertNoOpOperator},
      {"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2, 1>},
      {"OneHot", ConvertOneHotOperator},
      {"Pack", ConvertPackOperator},
      {"Pad", ConvertSimpleOperator<PadOperator, 2, 1>},
      {"PadV2", ConvertSimpleOperator<PadV2Operator, 3, 1>},
      {"ParallelDynamicStitch", ConvertDynamicStitchOperator},
      {"Placeholder", ConvertPlaceholderOperator},
      {"PlaceholderWithDefault", ConvertIdentityOperator},
      {"Pow", ConvertSimpleOperator<PowOperator, 2, 1>},
      {"Prod", ConvertReduceOperator<TensorFlowProdOperator>},
      {"RandomUniform", ConvertRandomUniform},
      {"Range", ConvertRangeOperator},
      {"Rank", ConvertSimpleOperator<TensorFlowRankOperator, 1, 1>},
      {"RealDiv", ConvertSimpleOperator<DivOperator, 2, 1>},
      {"Relu", ConvertSimpleOperator<ReluOperator, 1, 1>},
      {"Relu6", ConvertSimpleOperator<Relu6Operator, 1, 1>},
      {"Reshape", ConvertSimpleOperator<TensorFlowReshapeOperator, 2, 1>},
      {"ResizeBilinear", ConvertResizeBilinearOperator},
      {"ResizeNearestNeighbor", ConvertResizeNearestNeighborOperator},
      {"ReverseSequence", ConvertReverseSequenceOperator},
      {"ReverseV2", ConvertSimpleOperator<ReverseV2Operator, 2, 1>},
      {"Round", ConvertRoundOperator},
      {"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1, 1>},
      {"ScatterNd", ConvertSimpleOperator<ScatterNdOperator, 3, 1>},
      {"SegmentSum", ConvertSimpleOperator<SegmentSumOperator, 2, 1>},
      {"Select", ConvertSimpleOperator<SelectOperator, 3, 1>},
      {"SelectV2", ConvertSimpleOperator<SelectOperator, 3, 1>},
      {"Shape", ConvertShapeOperator},
      {"Sigmoid", ConvertSimpleOperator<LogisticOperator, 1, 1>},
      {"Sin", ConvertSimpleOperator<SinOperator, 1, 1>},
      {"Slice", ConvertSimpleOperator<SliceOperator, 3, 1>},
      {"Softmax", ConvertSoftmaxOperator},
      {"SpaceToBatchND", ConvertSpaceToBatchNDOperator},
      {"SpaceToDepth", ConvertSpaceToDepthOperator},
      {"SparseToDense", ConvertSparseToDenseOperator},
      {"Split", ConvertSplitOperator},
      {"SplitV", ConvertSplitVOperator},
      {"Sqrt", ConvertSimpleOperator<TensorFlowSqrtOperator, 1, 1>},
      {"Square", ConvertSimpleOperator<TensorFlowSquareOperator, 1, 1>},
      {"SquaredDifference",
       ConvertSimpleOperator<SquaredDifferenceOperator, 2, 1>},
      {"Snapshot", ConvertIdentityOperator},
      {"Squeeze", ConvertSqueezeOperator},
      {"StopGradient", ConvertIdentityOperator},
      {"StridedSlice", ConvertStridedSliceOperator},
      {"Sub", ConvertSimpleOperator<SubOperator, 2, 1>},
      {"Sum", ConvertReduceOperator<TensorFlowSumOperator>},
      {"Svdf", ConvertSvdfOperator},
      {"Switch", ConvertSwitchOperator},
      {"Tanh", ConvertSimpleOperator<TanhOperator, 1, 1>},
      {"Tile", ConvertSimpleOperator<TensorFlowTileOperator, 2, 1>},
      {"TopK", ConvertTopKV2Operator},
      {"TopKV2", ConvertTopKV2Operator},
      {"Transpose", ConvertSimpleOperator<TransposeOperator, 2, 1>},
      {"Unpack", ConvertUnpackOperator},
      {"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1, 1>},
      {"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm},
      {"UnidirectionalSequenceRnn", ConvertUnidirectionalSequenceRnn},
      {"MirrorPad", ConvertMirrorPadOperator},
      {"Unique", ConvertSimpleOperator<UniqueOperator, 1, 2>},
      {"Where", ConvertSimpleOperator<WhereOperator, 1, 1>},
  });
}
2723
2724tensorflow::Status ImportTensorFlowNode(
2725 const tensorflow::NodeDef& node,
2726 const TensorFlowImportFlags& tf_import_flags, const ModelFlags& model_flags,
2727 Model* model, const ConverterMapType& converter_map) {
2728 auto converter = converter_map.find(node.op());
2729 if (converter == converter_map.end()) {
2730 return ConvertUnsupportedOperator(node, tf_import_flags, model_flags,
2731 model);
2732 } else {
2733 return converter->second(node, tf_import_flags, model_flags, model);
2734 }
2735}
2736} // namespace internal
2737
2738std::unique_ptr<Model> ImportTensorFlowGraphDef(
2739 const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
2740 const GraphDef& tf_graph) {
2741 LogDumpGraphDef(kLogLevelModelChanged, "AT IMPORT", tf_graph);
2742
2743 GraphDef inlined_graph(tf_graph);
2744 if (InlineAllFunctions(&inlined_graph)) {
2745 LogDumpGraphDef(kLogLevelModelChanged, "AFTER INLINING", inlined_graph);
2746 }
2747
2748 // Check input and output specification.
2749 for (const auto& specified_input_array : model_flags.input_arrays()) {
2750 CHECK(!absl::EndsWith(specified_input_array.name(), ":0"))
2751 << "Unsupported explicit zero output index: "
2752 << specified_input_array.name();
2753 }
2754 for (const std::string& specified_output_array :
2755 model_flags.output_arrays()) {
2756 CHECK(!absl::EndsWith(specified_output_array, ":0"))
2757 << "Unsupported explicit zero output index: " << specified_output_array;
2758 }
2759
2760 Model* model = new Model;
2761 internal::ConverterMapType converter_map;
2762
2763 // This is used for the TFLite "Full Flex Mode" conversion. All the ops are
2764 // imported as `TensorFlowUnsupportedOperator`, and later all these ops are
2765 // converted to TFLite Flex ops.
2766 if (!tf_import_flags.import_all_ops_as_unsupported) {
2767 converter_map = internal::GetTensorFlowNodeConverterMap();
2768 } else {
2769 converter_map = internal::GetTensorFlowNodeConverterMapForFlex();
2770 }
2771
2772 for (auto node : inlined_graph.node()) {
2773 StripZeroOutputIndexFromInputs(&node);
2774 auto status = internal::ImportTensorFlowNode(
2775 node, tf_import_flags, model_flags, model, converter_map);
2776 CHECK(status.ok()) << status.error_message();
2777 }
2778
2779 ResolveModelFlags(model_flags, model);
2780
2781 StripCaretFromArrayNames(model);
2782 AddExtraOutputs(model);
2783 FixNoMissingArray(model);
2784 FixNoOrphanedArray(model);
2785 FixOperatorOrdering(model);
2786 CheckInvariants(*model);
2787
2788 // if rnn state arrays are constant, make them transient
2789 for (const auto& rnn_state : model->flags.rnn_states()) {
2790 model->GetArray(rnn_state.state_array()).buffer = nullptr;
2791 }
2792
2793 return std::unique_ptr<Model>(model);
2794}
2795
2796std::unique_ptr<Model> ImportTensorFlowGraphDef(
2797 const ModelFlags& model_flags, const TensorFlowImportFlags& tf_import_flags,
2798 const std::string& input_file_contents) {
2799 std::unique_ptr<GraphDef> tf_graph(new GraphDef);
2800 CHECK(ParseFromStringEitherTextOrBinary(input_file_contents, tf_graph.get()));
2801
2802 std::unique_ptr<GraphDef> pruned_graph =
2803 MaybeReplaceCompositeSubgraph(*tf_graph);
2804 if (pruned_graph) {
2805 tf_graph = std::move(pruned_graph);
2806 }
2807 return ImportTensorFlowGraphDef(model_flags, tf_import_flags, *tf_graph);
2808}
2809
2810} // namespace toco
2811