/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "google/protobuf/map.h"
#include "google/protobuf/text_format.h"
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/runtime/types.h"
#include "tensorflow/lite/toco/tensorflow_util.h"
#include "tensorflow/lite/toco/tooling_util.h"

using tensorflow::DT_BOOL;
using tensorflow::DT_COMPLEX64;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT16;
using tensorflow::DT_INT32;
using tensorflow::DT_INT64;
using tensorflow::DT_UINT32;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::TensorProto;

namespace toco {
namespace {

tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type,
                                           const std::string& error_location) {
  switch (data_type) {
    case ArrayDataType::kBool:
      return tensorflow::DT_BOOL;
    case ArrayDataType::kFloat:
      return tensorflow::DT_FLOAT;
    case ArrayDataType::kUint8:
      return tensorflow::DT_UINT8;
    case ArrayDataType::kInt16:
      return tensorflow::DT_INT16;
    case ArrayDataType::kUint16:
      return tensorflow::DT_UINT16;
    case ArrayDataType::kInt32:
      return tensorflow::DT_INT32;
    case ArrayDataType::kUint32:
      return tensorflow::DT_UINT32;
    case ArrayDataType::kInt64:
      return tensorflow::DT_INT64;
    case ArrayDataType::kString:
      return tensorflow::DT_STRING;
    case ArrayDataType::kComplex64:
      return tensorflow::DT_COMPLEX64;
    default:
    case ArrayDataType::kNone:
      LOG(FATAL) << "Unsupported data type '" << ArrayDataTypeName(data_type)
                 << "' in " << error_location;
      return tensorflow::DT_INVALID;
  }
}

tensorflow::DataType GetTensorFlowDataTypeForOp(ArrayDataType data_type,
                                                const std::string& op_name) {
  return GetTensorFlowDataType(data_type, "op '" + op_name + "'");
}

tensorflow::DataType GetTensorFlowDataType(const Model& model,
                                           const std::string& array_name) {
  return GetTensorFlowDataType(model.GetArray(array_name).data_type,
                               "array '" + array_name + "'");
}
// TensorFlow sometimes forbids what it calls "legacy scalars",
// which are 1-D shapes whose single dimension has size 1.
// See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars.
// For that reason, we generally avoid creating legacy scalars,
// by detecting the case where a 1-D shape would be of size 1 and
// replacing it with a 0-D shape.
// However, there is one special circumstance where we must not do that,
// and must unconditionally create a 1-D shape even if it is going to
// be of size 1: bias vectors consumed by BiasAdd nodes.
// Indeed, TensorFlow requires bias vectors to be 1-D; in the case of
// a depth of 1, that would be a legacy scalar, so in that case we
// must go ahead and keep the shape 1-D, letting it be a legacy scalar.
enum class LegacyScalarPolicy { kAvoidLegacyScalars, kDoCreateLegacyScalars };
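
// Illustrative sketch (not part of the exporter logic): how the two policies
// affect a single-element float array. `data` here is a hypothetical pointer
// to one float; ExportFloatArray is defined just below.
//
//   TensorProto proto;
//   ExportFloatArray(Shape({1}), data, &proto,
//                    LegacyScalarPolicy::kAvoidLegacyScalars);
//   // proto.tensor_shape() is rank-0: the size-1 dim was dropped.
//
//   ExportFloatArray(Shape({1}), data, &proto,
//                    LegacyScalarPolicy::kDoCreateLegacyScalars);
//   // proto.tensor_shape() is [1]: the legacy scalar is kept, as BiasAdd
//   // requires for bias inputs.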

void ExportFloatArray(const Shape& input_shape, const float* input_data,
                      TensorProto* output_tensor,
                      LegacyScalarPolicy legacy_scalar_policy) {
  output_tensor->set_dtype(DT_FLOAT);
  const int input_flat_size = RequiredBufferSizeForShape(input_shape);
  auto* shape = output_tensor->mutable_tensor_shape();

  const int kDims = input_shape.dimensions_count();
  if (legacy_scalar_policy == LegacyScalarPolicy::kDoCreateLegacyScalars ||
      kDims > 1 || (kDims == 1 && input_shape.dims(0) > 1)) {
    for (int i = 0; i < kDims; ++i) {
      shape->add_dim()->set_size(input_shape.dims(i));
    }
  }
  output_tensor->set_tensor_content(
      std::string(reinterpret_cast<const char*>(input_data),
                  sizeof(*input_data) * input_flat_size));
}

void ExportFloatArray(AxesOrder input_axes_order, const Shape& input_shape,
                      const float* input_data, AxesOrder output_axes_order,
                      TensorProto* output_tensor,
                      LegacyScalarPolicy legacy_scalar_policy) {
  CHECK_EQ(AxesCount(output_axes_order), AxesCount(input_axes_order));
  output_tensor->set_dtype(DT_FLOAT);
  CHECK_EQ(input_shape.dimensions_count(), AxesCount(input_axes_order));
  const int input_flat_size = RequiredBufferSizeForShape(input_shape);

  Shape shuffled_shape;
  ShuffleDims(input_shape, input_axes_order, output_axes_order,
              &shuffled_shape);
  std::vector<float> shuffled_data(input_flat_size);
  ShuffleArray(input_shape, input_axes_order, output_axes_order, shuffled_shape,
               input_data, shuffled_data.data());

  ExportFloatArray(shuffled_shape, shuffled_data.data(), output_tensor,
                   legacy_scalar_policy);
}

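// Returns true if a Const node with the given name has already been emitted
// into `tensorflow_graph`. This is a linear scan over all nodes, so const
// deduplication costs O(#nodes) per lookup; the various Convert*TensorConst
// helpers below use it to stay idempotent.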
bool HasAlreadyExportedConst(const std::string& name,
                             const GraphDef& tensorflow_graph) {
  for (const auto& node : tensorflow_graph.node()) {
    if (node.op() == "Const" && node.name() == name) {
      return true;
    }
  }
  return false;
}

void ConvertFloatTensorConst(const std::string& name, const Shape& input_shape,
                             const float* input_data,
                             AxesOrder input_axes_order,
                             AxesOrder output_axes_order,
                             GraphDef* tensorflow_graph,
                             LegacyScalarPolicy legacy_scalar_policy) {
  if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
    return;
  }
  tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
  const_op->set_op("Const");
  const_op->set_name(name);
  (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
  auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
  ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order,
                   tensor, legacy_scalar_policy);
}

void ConvertFloatTensorConst(const std::string& name, const Shape& input_shape,
                             const float* input_data,
                             AxesOrder input_axes_order,
                             AxesOrder output_axes_order,
                             GraphDef* tensorflow_graph) {
  ConvertFloatTensorConst(name, input_shape, input_data, input_axes_order,
                          output_axes_order, tensorflow_graph,
                          LegacyScalarPolicy::kAvoidLegacyScalars);
}

void ConvertFloatTensorConst(const Model& model, const std::string& name,
                             AxesOrder input_axes_order,
                             AxesOrder output_axes_order,
                             GraphDef* tensorflow_graph) {
  if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
    return;
  }
  CHECK(model.HasArray(name));
  const auto& input_array = model.GetArray(name);
  const auto& input_shape = input_array.shape();
  CHECK(input_array.buffer);
  CHECK(input_array.buffer->type == ArrayDataType::kFloat);
  const float* input_data =
      input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
  ConvertFloatTensorConst(name, input_shape, input_data, input_axes_order,
                          output_axes_order, tensorflow_graph);
}

void ConvertFloatTensorConst(const Model& model, const std::string& name,
                             GraphDef* tensorflow_graph) {
  if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
    return;
  }
  tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
  const_op->set_op("Const");
  const_op->set_name(name);
  (*const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
  auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
  CHECK(model.HasArray(name));
  const auto& input_array = model.GetArray(name);
  const auto& input_shape = input_array.shape();
  CHECK(input_array.buffer);
  CHECK(input_array.buffer->type == ArrayDataType::kFloat);
  const float* input_data =
      input_array.GetBuffer<ArrayDataType::kFloat>().data.data();
  ExportFloatArray(input_shape, input_data, tensor,
                   LegacyScalarPolicy::kAvoidLegacyScalars);
}

void ConvertBoolTensorConst(const Model& model, const std::string& name,
                            GraphDef* tensorflow_graph) {
  if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
    return;
  }
  CHECK(model.HasArray(name));
  const auto& array = model.GetArray(name);
  tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
  const_op->set_op("Const");
  const_op->set_name(name);
  (*const_op->mutable_attr())["dtype"].set_type(DT_BOOL);
  auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_BOOL);
  const auto& data = array.GetBuffer<ArrayDataType::kBool>().data;
  for (auto index : data) {
    tensor->add_bool_val(index);
  }
  const auto& array_shape = array.shape();
  auto* shape = tensor->mutable_tensor_shape();
  for (int i = 0; i < array_shape.dimensions_count(); i++) {
    shape->add_dim()->set_size(array_shape.dims(i));
  }
}

void ConvertIntTensorConst(const Model& model, const std::string& name,
                           GraphDef* tensorflow_graph) {
  if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
    return;
  }
  CHECK(model.HasArray(name));
  const auto& array = model.GetArray(name);
  tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
  const_op->set_op("Const");
  const_op->set_name(name);
  (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
  auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_INT32);
  const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
  for (auto index : data) {
    tensor->add_int_val(index);
  }
  const auto& array_shape = array.shape();
  auto* shape = tensor->mutable_tensor_shape();
  for (int i = 0; i < array_shape.dimensions_count(); i++) {
    shape->add_dim()->set_size(array_shape.dims(i));
  }
}

void CreateIntTensorConst(const std::string& name,
                          const std::vector<int32>& data,
                          const std::vector<int32>& shape,
                          GraphDef* tensorflow_graph) {
  if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
    return;
  }
  tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
  const_op->set_op("Const");
  const_op->set_name(name);
  (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
  auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_INT32);
  for (auto index : data) {
    tensor->add_int_val(index);
  }
  auto* tensor_shape = tensor->mutable_tensor_shape();
  int num_elements = 1;
  for (int size : shape) {
    tensor_shape->add_dim()->set_size(size);
    num_elements *= size;
  }
  CHECK_EQ(num_elements, data.size());
}
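
// Example usage of CreateIntTensorConst (illustrative; it mirrors the call
// site in ConvertFullyConnectedOperator below, and the name is hypothetical):
//
//   CreateIntTensorConst("fc/transpose_weights/perm", /*data=*/{1, 0},
//                        /*shape=*/{2}, tensorflow_graph);
//
// Passing an empty `shape` yields a rank-0 (scalar) int32 tensor, as done for
// the constant axis in ConvertGatherOperator.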

void ConvertComplex64TensorConst(const Model& model, const std::string& name,
                                 GraphDef* tensorflow_graph) {
  if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
    return;
  }
  CHECK(model.HasArray(name));
  const auto& array = model.GetArray(name);
  tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
  const_op->set_op("Const");
  const_op->set_name(name);
  (*const_op->mutable_attr())["dtype"].set_type(DT_COMPLEX64);
  auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_COMPLEX64);
  const auto& data = array.GetBuffer<ArrayDataType::kComplex64>().data;
  for (auto index : data) {
    tensor->add_scomplex_val(std::real(index));
    tensor->add_scomplex_val(std::imag(index));
  }
  const auto& array_shape = array.shape();
  auto* shape = tensor->mutable_tensor_shape();
  for (int i = 0; i < array_shape.dimensions_count(); i++) {
    shape->add_dim()->set_size(array_shape.dims(i));
  }
}

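// Emits a Const holding a 2-element int32 shape vector. Note the serialized
// order is {cols, rows}: ConvertFullyConnectedOperator passes
// rows = weights.dims(1) and cols = -1 to obtain the [-1, depth] shape that
// its inserted Reshape node expects.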
void CreateMatrixShapeTensorConst(const std::string& name, int rows, int cols,
                                  GraphDef* tensorflow_graph) {
  if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
    return;
  }
  tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
  const_op->set_op("Const");
  const_op->set_name(name);
  (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
  auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_INT32);
  const int32 data[2] = {cols, rows};
  tensor->set_tensor_content(
      std::string(reinterpret_cast<const char*>(data), sizeof(data)));
  auto* shape = tensor->mutable_tensor_shape();
  shape->add_dim()->set_size(2);
}

void CreateDummyConcatDimTensorConst(const std::string& name, int dim,
                                     GraphDef* tensorflow_graph) {
  if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
    return;
  }
  tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
  const_op->set_op("Const");
  const_op->set_name(name);
  (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
  auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_INT32);
  tensor->add_int_val(dim);
}

void CreateReshapeShapeTensorConst(const std::string& name,
                                   const std::vector<int32>& shape,
                                   GraphDef* tensorflow_graph) {
  if (HasAlreadyExportedConst(name, *tensorflow_graph)) {
    return;
  }
  tensorflow::NodeDef* const_op = tensorflow_graph->add_node();
  const_op->set_op("Const");
  const_op->set_name(name);
  (*const_op->mutable_attr())["dtype"].set_type(DT_INT32);
  auto* tensor = (*const_op->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_INT32);
  for (auto s : shape) {
    tensor->add_int_val(s);
  }
  // TensorFlow sometimes forbids what it calls "legacy scalars",
  // which are 1-D shapes whose single dimension has size 1.
  // See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars.
  if (shape.size() > 1) {
    auto* tensor_shape = tensor->mutable_tensor_shape();
    tensor_shape->add_dim()->set_size(shape.size());
  }
}

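// Returns the name of a constant array feeding `name`: either `name` itself,
// if it already has a buffer, or, when `name` is produced by a FakeQuant op,
// that FakeQuant's (constant) input. Anything else is a fatal error.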
std::string WalkUpToConstantArray(const Model& model, const std::string& name) {
  const Array& original_array = model.GetArray(name);
  if (original_array.buffer) {
    return name;
  }
  const auto* op = GetOpWithOutput(model, name);
  CHECK(op);
  CHECK(op->type == OperatorType::kFakeQuant);
  const std::string& input_of_fakequant_name = op->inputs[0];
  const Array& input_of_fakequant = model.GetArray(input_of_fakequant_name);
  CHECK(input_of_fakequant.buffer);
  return input_of_fakequant_name;
}

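// Emits a Conv2D node, plus a trailing BiasAdd when a third (bias) input is
// present. With a bias, the convolution itself takes the name
// "<output>/conv" and the BiasAdd takes the original output name, e.g.
// (illustrative) output "conv1" yields Conv2D "conv1/conv" feeding BiasAdd
// "conv1". Weights are re-exported from toco's OHWI order into TensorFlow's
// HWIO order.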
void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
                         GraphDef* tensorflow_graph) {
  const bool has_bias = src_op.inputs.size() >= 3;
  std::string conv_output = src_op.outputs[0];
  if (has_bias) {
    conv_output += "/conv";
  }

  tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
  conv2d_op->set_op("Conv2D");
  conv2d_op->set_name(conv_output);
  *conv2d_op->add_input() = src_op.inputs[0];
  *conv2d_op->add_input() = src_op.inputs[1];
  (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
  const std::string& weights_array_name =
      WalkUpToConstantArray(model, src_op.inputs[1]);
  const auto& weights_array = model.GetArray(weights_array_name);
  CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
  ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
                          AxesOrder::kHWIO, tensorflow_graph);
  auto& strides = (*conv2d_op->mutable_attr())["strides"];
  strides.mutable_list()->add_i(1);
  strides.mutable_list()->add_i(src_op.stride_height);
  strides.mutable_list()->add_i(src_op.stride_width);
  strides.mutable_list()->add_i(1);
  if ((src_op.dilation_width_factor != 1) ||
      (src_op.dilation_height_factor != 1)) {
    auto& dilations = (*conv2d_op->mutable_attr())["dilations"];
    dilations.mutable_list()->add_i(1);
    dilations.mutable_list()->add_i(src_op.dilation_height_factor);
    dilations.mutable_list()->add_i(src_op.dilation_width_factor);
    dilations.mutable_list()->add_i(1);
  }
  std::string padding;
  if (src_op.padding.type == PaddingType::kSame) {
    padding = "SAME";
  } else if (src_op.padding.type == PaddingType::kValid) {
    padding = "VALID";
  } else {
    LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
  }
  (*conv2d_op->mutable_attr())["padding"].set_s(padding);

  if (has_bias) {
    tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
    biasadd_op->set_op("BiasAdd");
    biasadd_op->set_name(src_op.outputs[0]);
    biasadd_op->add_input(conv_output);
    biasadd_op->add_input(src_op.inputs[2]);
    (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
    CHECK(model.HasArray(src_op.inputs[2]));
    const std::string& bias_array_name =
        WalkUpToConstantArray(model, src_op.inputs[2]);
    const auto& bias_array = model.GetArray(bias_array_name);
    // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
    Shape bias_shape_1d = bias_array.shape();
    UnextendShape(&bias_shape_1d, 1);
    CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
    const float* bias_data =
        bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
    ConvertFloatTensorConst(bias_array_name, bias_shape_1d, bias_data,
                            AxesOrder::kOneAxis, AxesOrder::kOneAxis,
                            tensorflow_graph,
                            LegacyScalarPolicy::kDoCreateLegacyScalars);
  }
}

void ConvertDepthwiseConvOperator(const Model& model,
                                  const DepthwiseConvOperator& src_op,
                                  GraphDef* tensorflow_graph) {
  const bool has_bias = src_op.inputs.size() >= 3;
  std::string conv_output = src_op.outputs[0];
  if (has_bias) {
    conv_output += "/conv";
  }

  tensorflow::NodeDef* dc2d_op = tensorflow_graph->add_node();
  dc2d_op->set_op("DepthwiseConv2dNative");
  dc2d_op->set_name(conv_output);
  *dc2d_op->add_input() = src_op.inputs[0];
  *dc2d_op->add_input() = src_op.inputs[1];
  (*dc2d_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Our internal DepthwiseConv weights are 1 x H x W x OutputDepth.
  // We need to convert that to H x W x InputDepth x Multiplier.
  // That's only a matter of constructing a Dims object; the actual
  // array layout is the same.
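  // E.g. (illustrative): 1 x 3 x 3 x 8 weights with depth_multiplier == 2
  // become 3 x 3 x 4 x 2 -- the same 72 floats, only the shape metadata
  // changes.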
  CHECK(model.HasArray(src_op.inputs[1]));
  const std::string& src_weights_name =
      WalkUpToConstantArray(model, src_op.inputs[1]);
  const auto& src_weights_array = model.GetArray(src_weights_name);
  const auto& src_weights_shape = src_weights_array.shape();
  CHECK_EQ(src_weights_shape.dimensions_count(), 4);
  const Shape dst_weights_shape =
      Shape({src_weights_shape.dims(1), src_weights_shape.dims(2),
             src_weights_shape.dims(3) / src_op.depth_multiplier,
             src_op.depth_multiplier});
  CHECK_EQ(src_weights_shape.dims(3) % src_op.depth_multiplier, 0);
  CHECK(dst_weights_shape.dims(2) * dst_weights_shape.dims(3) ==
        src_weights_shape.dims(3));
  CHECK_EQ(src_weights_shape.dims(0), 1);

  CHECK(src_weights_array.buffer->type == ArrayDataType::kFloat);
  const float* src_weights_data =
      src_weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
  ConvertFloatTensorConst(src_weights_name, dst_weights_shape, src_weights_data,
                          AxesOrder::kHWIM, AxesOrder::kHWIM, tensorflow_graph);

  auto& strides = (*dc2d_op->mutable_attr())["strides"];
  strides.mutable_list()->add_i(1);
  strides.mutable_list()->add_i(src_op.stride_height);
  strides.mutable_list()->add_i(src_op.stride_width);
  strides.mutable_list()->add_i(1);
  // TODO(b/116063589): To return a working TF GraphDef, we should be
  // inserting the appropriate SpaceToBatchND and BatchToSpaceND ops before
  // and after the conv, since TF's DepthwiseConv2dNative doesn't support
  // dilations directly.
  if ((src_op.dilation_width_factor != 1) ||
      (src_op.dilation_height_factor != 1)) {
    auto& dilations = (*dc2d_op->mutable_attr())["dilations"];
    dilations.mutable_list()->add_i(1);
    dilations.mutable_list()->add_i(src_op.dilation_height_factor);
    dilations.mutable_list()->add_i(src_op.dilation_width_factor);
    dilations.mutable_list()->add_i(1);
  }
  std::string padding;
  if (src_op.padding.type == PaddingType::kSame) {
    padding = "SAME";
  } else if (src_op.padding.type == PaddingType::kValid) {
    padding = "VALID";
  } else {
    LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
  }
  (*dc2d_op->mutable_attr())["padding"].set_s(padding);

  if (has_bias) {
    tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
    biasadd_op->set_op("BiasAdd");
    biasadd_op->set_name(src_op.outputs[0]);
    biasadd_op->add_input(conv_output);
    biasadd_op->add_input(src_op.inputs[2]);
    (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
    CHECK(model.HasArray(src_op.inputs[2]));
    const std::string& bias_name =
        WalkUpToConstantArray(model, src_op.inputs[2]);
    const auto& bias_array = model.GetArray(bias_name);
    // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
    Shape bias_shape_1d = bias_array.shape();
    UnextendShape(&bias_shape_1d, 1);
    CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
    const float* bias_data =
        bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
    ConvertFloatTensorConst(bias_name, bias_shape_1d, bias_data,
                            AxesOrder::kOneAxis, AxesOrder::kOneAxis,
                            tensorflow_graph,
                            LegacyScalarPolicy::kDoCreateLegacyScalars);
  }
}

void ConvertTransposeConvOperator(const Model& model,
                                  const TransposeConvOperator& src_op,
                                  GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
  conv2d_op->set_op("Conv2DBackpropInput");
  conv2d_op->set_name(src_op.outputs[0]);
  *conv2d_op->add_input() = src_op.inputs[0];
  *conv2d_op->add_input() = src_op.inputs[1];
  *conv2d_op->add_input() = src_op.inputs[2];
  (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
  const std::string& weights_array_name = WalkUpToConstantArray(
      model, src_op.inputs[TransposeConvOperator::WEIGHTS]);
  const auto& weights_array = model.GetArray(weights_array_name);
  CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
  ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
                          AxesOrder::kHWOI, tensorflow_graph);
  auto& strides = (*conv2d_op->mutable_attr())["strides"];
  strides.mutable_list()->add_i(1);
  strides.mutable_list()->add_i(src_op.stride_height);
  strides.mutable_list()->add_i(src_op.stride_width);
  strides.mutable_list()->add_i(1);
  std::string padding;
  if (src_op.padding.type == PaddingType::kSame) {
    padding = "SAME";
  } else if (src_op.padding.type == PaddingType::kValid) {
    padding = "VALID";
  } else {
    LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
  }
  (*conv2d_op->mutable_attr())["padding"].set_s(padding);
}

void ConvertDepthToSpaceOperator(const Model& model,
                                 const DepthToSpaceOperator& src_op,
                                 GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* op = tensorflow_graph->add_node();
  op->set_op("DepthToSpace");
  op->set_name(src_op.outputs[0]);
  *op->add_input() = src_op.inputs[0];
  (*op->mutable_attr())["T"].set_type(DT_FLOAT);
  (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
}

void ConvertSpaceToDepthOperator(const Model& model,
                                 const SpaceToDepthOperator& src_op,
                                 GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* op = tensorflow_graph->add_node();
  op->set_op("SpaceToDepth");
  op->set_name(src_op.outputs[0]);
  *op->add_input() = src_op.inputs[0];
  (*op->mutable_attr())["T"].set_type(DT_FLOAT);
  (*op->mutable_attr())["block_size"].set_i(src_op.block_size);
}

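// Expands a toco FullyConnected op into the TensorFlow subgraph
//   Reshape(input, [-1, depth]) -> MatMul(., Transpose(weights)) [-> BiasAdd]
// The inserted Transpose undoes the weights transposition performed by
// ResolveTensorFlowMatMul during import.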
void ConvertFullyConnectedOperator(const Model& model,
                                   const FullyConnectedOperator& src_op,
                                   GraphDef* tensorflow_graph) {
  // Reshape input activations to have the shape expected by the MatMul.
  const std::string reshape_output =
      AvailableArrayName(model, src_op.outputs[0] + "/reshape");
  const std::string reshape_shape =
      AvailableArrayName(model, reshape_output + "/shape");
  const auto& fc_weights_array = model.GetArray(src_op.inputs[1]);
  const auto& fc_weights_shape = fc_weights_array.shape();
  CHECK_EQ(fc_weights_shape.dimensions_count(), 2);
  CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1,
                               tensorflow_graph);
  tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
  reshape_op->set_op("Reshape");
  reshape_op->set_name(reshape_output);
  reshape_op->add_input(src_op.inputs[0]);
  reshape_op->add_input(reshape_shape);
  (*reshape_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));

  const bool has_bias = src_op.inputs.size() >= 3;
  std::string matmul_output = src_op.outputs[0];
  if (has_bias) {
    matmul_output += "/matmul";
  }

  // Transpose the RHS input from column-major to row-major to match TensorFlow
  // expectations. This is the inverse of the transpose we do during
  // ResolveTensorFlowMatMul.
  const std::string transpose_output =
      AvailableArrayName(model, matmul_output + "/transpose_weights");
  const std::string transpose_perm =
      AvailableArrayName(model, transpose_output + "/perm");
  CreateIntTensorConst(transpose_perm, {1, 0}, {2}, tensorflow_graph);
  tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node();
  transpose_op->set_op("Transpose");
  transpose_op->set_name(transpose_output);
  *transpose_op->add_input() = src_op.inputs[1];
  *transpose_op->add_input() = transpose_perm;
  (*transpose_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[1]));
  (*transpose_op->mutable_attr())["Tperm"].set_type(DT_INT32);

  tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
  matmul_op->set_op("MatMul");
  matmul_op->set_name(matmul_output);
  *matmul_op->add_input() = reshape_output;
  *matmul_op->add_input() = transpose_op->name();
  (*matmul_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));
  (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
  (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
  CHECK(model.HasArray(src_op.inputs[1]));

  // Add the bias, if it exists.
  if (has_bias) {
    tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
    biasadd_op->set_op("BiasAdd");
    biasadd_op->set_name(src_op.outputs[0]);
    biasadd_op->add_input(matmul_output);
    biasadd_op->add_input(src_op.inputs[2]);
    (*biasadd_op->mutable_attr())["T"].set_type(
        GetTensorFlowDataType(model, src_op.inputs[0]));
    CHECK(model.HasArray(src_op.inputs[2]));
    const auto& bias_array = model.GetArray(src_op.inputs[2]);
    // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
    Shape bias_shape_1d = bias_array.shape();
    UnextendShape(&bias_shape_1d, 1);
    CHECK(bias_array.buffer);
    CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
    const float* bias_data =
        bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
    ConvertFloatTensorConst(WalkUpToConstantArray(model, src_op.inputs[2]),
                            bias_shape_1d, bias_data, AxesOrder::kOneAxis,
                            AxesOrder::kOneAxis, tensorflow_graph,
                            LegacyScalarPolicy::kDoCreateLegacyScalars);
  }
}

void ConvertAddOperator(const Model& model, const AddOperator& src_op,
                        GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
  add_op->set_op("Add");
  add_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *add_op->add_input() = src_op.inputs[0];
  *add_op->add_input() = src_op.inputs[1];
  (*add_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.outputs[0]));
}

void ConvertAddNOperator(const Model& model, const AddNOperator& src_op,
                         GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* add_op = tensorflow_graph->add_node();
  add_op->set_op("AddN");
  add_op->set_name(src_op.outputs[0]);
  for (const auto& input : src_op.inputs) {
    *add_op->add_input() = input;
  }
  (*add_op->mutable_attr())["N"].set_i(src_op.inputs.size());
  (*add_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.outputs[0]));
}

void ConvertMulOperator(const Model& model, const MulOperator& src_op,
                        GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* mul_op = tensorflow_graph->add_node();
  mul_op->set_op("Mul");
  mul_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *mul_op->add_input() = src_op.inputs[0];
  *mul_op->add_input() = src_op.inputs[1];
  (*mul_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.outputs[0]));
}

void ConvertDivOperator(const Model& model, const DivOperator& src_op,
                        GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* div_op = tensorflow_graph->add_node();
  div_op->set_op("Div");
  div_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *div_op->add_input() = src_op.inputs[0];
  *div_op->add_input() = src_op.inputs[1];
  (*div_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.outputs[0]));
}

void ConvertReluOperator(const Model& model, const ReluOperator& src_op,
                         GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
  relu_op->set_op("Relu");
  relu_op->set_name(src_op.outputs[0]);
  *relu_op->add_input() = src_op.inputs[0];
  (*relu_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.outputs[0]));
}

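// TensorFlow has no single Relu1 op, so Relu1 is emitted as
// Minimum(Maximum(input, -1.0), 1.0). Note the naming: "max_bounds" holds the
// lower clamp value (-1) fed to Maximum, and "min_bounds" the upper clamp
// value (+1) fed to Minimum.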
void ConvertRelu1Operator(const Relu1Operator& src_op,
                          GraphDef* tensorflow_graph) {
  const std::string max_bounds = src_op.outputs[0] + "/max_bounds";
  const std::string min_bounds = src_op.outputs[0] + "/min_bounds";
  const std::string max_output = src_op.outputs[0] + "/max_output";

  tensorflow::NodeDef* max_bounds_const_op = tensorflow_graph->add_node();
  max_bounds_const_op->set_op("Const");
  max_bounds_const_op->set_name(max_bounds);
  (*max_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
  auto* max_bounds_const_op_tensor =
      (*max_bounds_const_op->mutable_attr())["value"].mutable_tensor();
  max_bounds_const_op_tensor->set_dtype(DT_FLOAT);
  max_bounds_const_op_tensor->add_float_val(-1.0f);

  tensorflow::NodeDef* min_bounds_const_op = tensorflow_graph->add_node();
  min_bounds_const_op->set_op("Const");
  min_bounds_const_op->set_name(min_bounds);
  (*min_bounds_const_op->mutable_attr())["dtype"].set_type(DT_FLOAT);
  auto* min_bounds_const_op_tensor =
      (*min_bounds_const_op->mutable_attr())["value"].mutable_tensor();
  min_bounds_const_op_tensor->set_dtype(DT_FLOAT);
  min_bounds_const_op_tensor->add_float_val(1.0f);

  tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
  max_op->set_op("Maximum");
  max_op->set_name(max_output);
  *max_op->add_input() = src_op.inputs[0];
  *max_op->add_input() = max_bounds;
  (*max_op->mutable_attr())["T"].set_type(DT_FLOAT);

  tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
  min_op->set_op("Minimum");
  min_op->set_name(src_op.outputs[0]);
  *min_op->add_input() = max_output;
  *min_op->add_input() = min_bounds;
  (*min_op->mutable_attr())["T"].set_type(DT_FLOAT);
}

void ConvertRelu6Operator(const Relu6Operator& src_op,
                          GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
  relu_op->set_op("Relu6");
  relu_op->set_name(src_op.outputs[0]);
  *relu_op->add_input() = src_op.inputs[0];
  (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
}

void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* op = tensorflow_graph->add_node();
  op->set_op("Log");
  op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 1);
  *op->add_input() = src_op.inputs[0];
  (*op->mutable_attr())["T"].set_type(DT_FLOAT);
}

void ConvertLogisticOperator(const LogisticOperator& src_op,
                             GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* relu_op = tensorflow_graph->add_node();
  relu_op->set_op("Sigmoid");
  relu_op->set_name(src_op.outputs[0]);
  *relu_op->add_input() = src_op.inputs[0];
  (*relu_op->mutable_attr())["T"].set_type(DT_FLOAT);
}

void ConvertTanhOperator(const TanhOperator& src_op,
                         GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* tanh_op = tensorflow_graph->add_node();
  tanh_op->set_op("Tanh");
  tanh_op->set_name(src_op.outputs[0]);
  *tanh_op->add_input() = src_op.inputs[0];
  (*tanh_op->mutable_attr())["T"].set_type(DT_FLOAT);
}

void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op,
                            GraphDef* tensorflow_graph) {
  std::string softmax_input;
  Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
  if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
    softmax_input = src_op.inputs[0];
  } else {
    // Insert a reshape operator that reduces the dimensions down to the 2 that
    // are required for TensorFlow Logits.
    const std::string reshape_output =
        src_op.outputs[0] + "/softmax_insert_reshape";
    const std::string softmax_size = src_op.outputs[0] + "/softmax_insert_size";
    softmax_input = reshape_output;

    tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
    reshape_op->set_op("Reshape");
    reshape_op->set_name(reshape_output);
    *reshape_op->add_input() = src_op.inputs[0];
    *reshape_op->add_input() = softmax_size;
    (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);

    const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
    int32_t flattened_size = 1;
    for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
      flattened_size *= input_shape.dims(i);
    }
    const std::vector<int32> shape_data = {
        flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
    CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
  }

  tensorflow::NodeDef* softmax_op = tensorflow_graph->add_node();
  softmax_op->set_op("Softmax");
  softmax_op->set_name(src_op.outputs[0]);
  *softmax_op->add_input() = softmax_input;
  // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter.
  CHECK_EQ(src_op.beta, 1.f);
  (*softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
}

void ConvertLogSoftmaxOperator(const Model& model,
                               const LogSoftmaxOperator& src_op,
                               GraphDef* tensorflow_graph) {
  std::string softmax_input;
  Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]);
  if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) {
    softmax_input = src_op.inputs[0];
  } else {
    // Insert a reshape operator that reduces the dimensions down to the 2 that
    // are required for TensorFlow Logits.
    const std::string reshape_output =
        src_op.outputs[0] + "/log_softmax_insert_reshape";
    const std::string softmax_size =
        src_op.outputs[0] + "/log_softmax_insert_size";
    softmax_input = reshape_output;

    tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
    reshape_op->set_op("Reshape");
    reshape_op->set_name(reshape_output);
    *reshape_op->add_input() = src_op.inputs[0];
    *reshape_op->add_input() = softmax_size;
    (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);

    const auto& input_shape = model.GetArray(src_op.inputs[0]).shape();
    int32_t flattened_size = 1;
    for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) {
      flattened_size *= input_shape.dims(i);
    }
    const std::vector<int32> shape_data = {
        flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)};
    CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph);
  }

  tensorflow::NodeDef* log_softmax_op = tensorflow_graph->add_node();
  log_softmax_op->set_op("LogSoftmax");
  log_softmax_op->set_name(src_op.outputs[0]);
  *log_softmax_op->add_input() = softmax_input;
  (*log_softmax_op->mutable_attr())["T"].set_type(DT_FLOAT);
}

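// Emits L2 normalization as x * Rsqrt(Sum(Square(x))), with the Sum reducing
// over axes {0, 1}.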
void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op,
                                    GraphDef* tensorflow_graph) {
  const std::string square_output = src_op.outputs[0] + "/square";
  const std::string sum_reduction_indices =
      src_op.outputs[0] + "/reduction_indices";
  const std::string sum_output = src_op.outputs[0] + "/sum";
  const std::string rsqrt_output = src_op.outputs[0] + "/rsqrt";
  const std::string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled";

  tensorflow::NodeDef* sum_reduction_indices_op = tensorflow_graph->add_node();
  sum_reduction_indices_op->set_op("Const");
  sum_reduction_indices_op->set_name(sum_reduction_indices);
  (*sum_reduction_indices_op->mutable_attr())["dtype"].set_type(DT_INT32);
  auto* sum_reduction_indices_tensor =
      (*sum_reduction_indices_op->mutable_attr())["value"].mutable_tensor();
  sum_reduction_indices_tensor->set_dtype(DT_INT32);
  auto* sum_reduction_indices_shape =
      sum_reduction_indices_tensor->mutable_tensor_shape();
  auto* sum_reduction_indices_dim = sum_reduction_indices_shape->add_dim();
  sum_reduction_indices_dim->set_size(2);
  sum_reduction_indices_tensor->add_int_val(0);
  sum_reduction_indices_tensor->add_int_val(1);

  tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
  square_op->set_op("Square");
  square_op->set_name(square_output);
  *square_op->add_input() = src_op.inputs[0];
  (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);

  tensorflow::NodeDef* sum_op = tensorflow_graph->add_node();
  sum_op->set_op("Sum");
  sum_op->set_name(sum_output);
  *sum_op->add_input() = square_output;
  *sum_op->add_input() = sum_reduction_indices;
  (*sum_op->mutable_attr())["T"].set_type(DT_FLOAT);

  tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node();
  rsqrt_op->set_op("Rsqrt");
  rsqrt_op->set_name(rsqrt_output);
  *rsqrt_op->add_input() = sum_output;
  (*rsqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);

  tensorflow::NodeDef* mul_op = tensorflow_graph->add_node();
  mul_op->set_op("Mul");
  mul_op->set_name(src_op.outputs[0]);
  *mul_op->add_input() = src_op.inputs[0];
  *mul_op->add_input() = rsqrt_output;
  (*mul_op->mutable_attr())["T"].set_type(DT_FLOAT);
}

void ConvertLocalResponseNormalizationOperator(
    const LocalResponseNormalizationOperator& src_op,
    GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* lrn_op = tensorflow_graph->add_node();
  lrn_op->set_op("LRN");
  lrn_op->set_name(src_op.outputs[0]);
  *lrn_op->add_input() = src_op.inputs[0];
  (*lrn_op->mutable_attr())["depth_radius"].set_i(src_op.range);
  (*lrn_op->mutable_attr())["bias"].set_f(src_op.bias);
  (*lrn_op->mutable_attr())["alpha"].set_f(src_op.alpha);
  (*lrn_op->mutable_attr())["beta"].set_f(src_op.beta);
}

void ConvertFakeQuantOperator(const FakeQuantOperator& src_op,
                              GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* fakequant_op = tensorflow_graph->add_node();
  fakequant_op->set_op("FakeQuantWithMinMaxArgs");
  fakequant_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 1);
  *fakequant_op->add_input() = src_op.inputs[0];
  CHECK(src_op.minmax);
  (*fakequant_op->mutable_attr())["min"].set_f(src_op.minmax->min);
  (*fakequant_op->mutable_attr())["max"].set_f(src_op.minmax->max);
  if (src_op.num_bits) {
    (*fakequant_op->mutable_attr())["num_bits"].set_i(src_op.num_bits);
  }
  if (src_op.narrow_range) {
    (*fakequant_op->mutable_attr())["narrow_range"].set_b(src_op.narrow_range);
  }
}

void ConvertMaxPoolOperator(const MaxPoolOperator& src_op,
                            GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* maxpool_op = tensorflow_graph->add_node();
  maxpool_op->set_op("MaxPool");
  maxpool_op->set_name(src_op.outputs[0]);
  *maxpool_op->add_input() = src_op.inputs[0];
  auto& strides = (*maxpool_op->mutable_attr())["strides"];
  strides.mutable_list()->add_i(1);
  strides.mutable_list()->add_i(src_op.stride_height);
  strides.mutable_list()->add_i(src_op.stride_width);
  strides.mutable_list()->add_i(1);
  std::string padding;
  if (src_op.padding.type == PaddingType::kSame) {
    padding = "SAME";
  } else if (src_op.padding.type == PaddingType::kValid) {
    padding = "VALID";
  } else {
    LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
  }
  (*maxpool_op->mutable_attr())["padding"].set_s(padding);
  (*maxpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
  auto& ksize = (*maxpool_op->mutable_attr())["ksize"];
  ksize.mutable_list()->add_i(1);
  ksize.mutable_list()->add_i(src_op.kheight);
  ksize.mutable_list()->add_i(src_op.kwidth);
  ksize.mutable_list()->add_i(1);
}

void ConvertAveragePoolOperator(const AveragePoolOperator& src_op,
                                GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node();
  avgpool_op->set_op("AvgPool");
  avgpool_op->set_name(src_op.outputs[0]);
  *avgpool_op->add_input() = src_op.inputs[0];
  auto& strides = (*avgpool_op->mutable_attr())["strides"];
  strides.mutable_list()->add_i(1);
  strides.mutable_list()->add_i(src_op.stride_height);
  strides.mutable_list()->add_i(src_op.stride_width);
  strides.mutable_list()->add_i(1);
  std::string padding;
  if (src_op.padding.type == PaddingType::kSame) {
    padding = "SAME";
  } else if (src_op.padding.type == PaddingType::kValid) {
    padding = "VALID";
  } else {
    LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
  }
  (*avgpool_op->mutable_attr())["padding"].set_s(padding);
  (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
  auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
  ksize.mutable_list()->add_i(1);
  ksize.mutable_list()->add_i(src_op.kheight);
  ksize.mutable_list()->add_i(src_op.kwidth);
  ksize.mutable_list()->add_i(1);
}

void ConvertConcatenationOperator(const Model& model,
                                  const ConcatenationOperator& src_op,
                                  GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* dc_op = tensorflow_graph->add_node();
  dc_op->set_op("ConcatV2");
  dc_op->set_name(src_op.outputs[0]);
  const std::string dummy_axis = src_op.outputs[0] + "/axis";
  CreateDummyConcatDimTensorConst(dummy_axis, src_op.axis, tensorflow_graph);
  for (const auto& input : src_op.inputs) {
    *dc_op->add_input() = input;
  }
  *dc_op->add_input() = dummy_axis;
  (*dc_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));
  (*dc_op->mutable_attr())["Tidx"].set_type(DT_INT32);
  (*dc_op->mutable_attr())["N"].set_i(src_op.inputs.size());
}

void ConvertTensorFlowReshapeOperator(const Model& model,
                                      const TensorFlowReshapeOperator& src_op,
                                      GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node();
  reshape_op->set_op("Reshape");
  reshape_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *reshape_op->add_input() = src_op.inputs[0];
  *reshape_op->add_input() = src_op.inputs[1];
  (*reshape_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.outputs[0]));
  const auto& shape_array = model.GetArray(src_op.inputs[1]);
  QCHECK(shape_array.data_type == ArrayDataType::kInt32)
      << "Only int32 shape is supported.";
  QCHECK(shape_array.buffer != nullptr)
      << "Shape inferred at runtime is not supported.";
  const auto& shape_data = shape_array.GetBuffer<ArrayDataType::kInt32>().data;
  CreateReshapeShapeTensorConst(src_op.inputs[1], shape_data, tensorflow_graph);
}

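// Emits L2 pooling as Sqrt(AvgPool(Square(x))), since TensorFlow has no
// native L2 pooling op.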
void ConvertL2PoolOperator(const L2PoolOperator& src_op,
                           GraphDef* tensorflow_graph) {
  const std::string square_output = src_op.outputs[0] + "/square";
  const std::string avgpool_output = src_op.outputs[0] + "/avgpool";

  tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
  square_op->set_op("Square");
  square_op->set_name(square_output);
  *square_op->add_input() = src_op.inputs[0];
  (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);

  std::string padding;
  if (src_op.padding.type == PaddingType::kSame) {
    padding = "SAME";
  } else if (src_op.padding.type == PaddingType::kValid) {
    padding = "VALID";
  } else {
    LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
  }

  tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node();
  avgpool_op->set_op("AvgPool");
  avgpool_op->set_name(avgpool_output);
  *avgpool_op->add_input() = square_output;
  auto& strides = (*avgpool_op->mutable_attr())["strides"];
  strides.mutable_list()->add_i(1);
  strides.mutable_list()->add_i(src_op.stride_height);
  strides.mutable_list()->add_i(src_op.stride_width);
  strides.mutable_list()->add_i(1);

  (*avgpool_op->mutable_attr())["padding"].set_s(padding);
  (*avgpool_op->mutable_attr())["T"].set_type(DT_FLOAT);
  auto& ksize = (*avgpool_op->mutable_attr())["ksize"];
  ksize.mutable_list()->add_i(1);
  ksize.mutable_list()->add_i(src_op.kheight);
  ksize.mutable_list()->add_i(src_op.kwidth);
  ksize.mutable_list()->add_i(1);

  tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node();
  sqrt_op->set_op("Sqrt");
  sqrt_op->set_name(src_op.outputs[0]);
  *sqrt_op->add_input() = avgpool_output;
  (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
}

void ConvertSquareOperator(const TensorFlowSquareOperator& src_op,
                           GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* square_op = tensorflow_graph->add_node();
  square_op->set_op("Square");
  square_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 1);
  *square_op->add_input() = src_op.inputs[0];
  (*square_op->mutable_attr())["T"].set_type(DT_FLOAT);
}

void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op,
                         GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node();
  sqrt_op->set_op("Sqrt");
  sqrt_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 1);
  *sqrt_op->add_input() = src_op.inputs[0];
  (*sqrt_op->mutable_attr())["T"].set_type(DT_FLOAT);
}

void ConvertRsqrtOperator(const Model& model,
                          const TensorFlowRsqrtOperator& src_op,
                          GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node();
  rsqrt_op->set_op("Rsqrt");
  rsqrt_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 1);
  *rsqrt_op->add_input() = src_op.inputs[0];
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*rsqrt_op->mutable_attr())["T"].set_type(data_type);
}

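// TensorFlow's Split op takes the split dimension as its first input; here it
// must be constant, and it is re-emitted as a scalar int32 Const (reusing
// CreateDummyConcatDimTensorConst, which writes a shapeless tensor).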
void ConvertSplitOperator(const Model& model,
                          const TensorFlowSplitOperator& src_op,
                          GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
  split_op->set_op("Split");
  split_op->set_name(src_op.outputs[0]);
  for (const auto& input : src_op.inputs) {
    *split_op->add_input() = input;
  }
  (*split_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.outputs[0]));
  (*split_op->mutable_attr())["num_split"].set_i(src_op.num_split);
  const auto& split_dim_array = model.GetArray(src_op.inputs[0]);
  CHECK(split_dim_array.buffer);
  CHECK(split_dim_array.data_type == ArrayDataType::kInt32);
  const auto& split_dim_data =
      split_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
  CHECK_EQ(split_dim_data.size(), 1);
  const int split_dim = split_dim_data[0];
  CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim,
                                  tensorflow_graph);
}

void ConvertSplitVOperator(const Model& model,
                           const TensorFlowSplitVOperator& src_op,
                           GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* split_v_op = tensorflow_graph->add_node();
  split_v_op->set_op("SplitV");
  split_v_op->set_name(src_op.outputs[0]);
  for (const auto& input : src_op.inputs) {
    *split_v_op->add_input() = input;
  }
  (*split_v_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.outputs[0]));
  (*split_v_op->mutable_attr())["Tlen"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[1]));
  (*split_v_op->mutable_attr())["num_split"].set_i(src_op.num_split);
  ConvertIntTensorConst(model, src_op.inputs[1], tensorflow_graph);
}

void ConvertCastOperator(const Model& model, const CastOperator& src_op,
                         GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* cast_op = tensorflow_graph->add_node();
  cast_op->set_op("Cast");
  cast_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 1);
  *cast_op->add_input() = src_op.inputs[0];

  (*cast_op->mutable_attr())["DstT"].set_type(
      GetTensorFlowDataType(model, src_op.outputs[0]));
  (*cast_op->mutable_attr())["SrcT"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));
}

void ConvertFloorOperator(const Model& model, const FloorOperator& src_op,
                          GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* floor_op = tensorflow_graph->add_node();
  floor_op->set_op("Floor");
  floor_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 1);
  *floor_op->add_input() = src_op.inputs[0];
  (*floor_op->mutable_attr())["T"].set_type(DT_FLOAT);
}

void ConvertCeilOperator(const Model& model, const CeilOperator& src_op,
                         GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* ceil_op = tensorflow_graph->add_node();
  ceil_op->set_op("Ceil");
  ceil_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 1);
  *ceil_op->add_input() = src_op.inputs[0];
  (*ceil_op->mutable_attr())["T"].set_type(DT_FLOAT);
}

void ConvertRoundOperator(const Model& model, const RoundOperator& src_op,
                          GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* round_op = tensorflow_graph->add_node();
  round_op->set_op("Round");
  round_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 1);
  *round_op->add_input() = src_op.inputs[0];
  (*round_op->mutable_attr())["T"].set_type(DT_FLOAT);
}

void ConvertGatherOperator(const Model& model, const GatherOperator& src_op,
                           GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* gather_op = tensorflow_graph->add_node();
  gather_op->set_op("GatherV2");
  gather_op->set_name(src_op.outputs[0]);
  *gather_op->add_input() = src_op.inputs[0];
  *gather_op->add_input() = src_op.inputs[1];

  if (!src_op.axis) {
    // Dynamic axis.
    CHECK_EQ(src_op.inputs.size(), 3);
    *gather_op->add_input() = src_op.inputs[2];
  } else {
    // Constant axis.
    CHECK_EQ(src_op.inputs.size(), 2);
    const std::string gather_axis =
        AvailableArrayName(model, gather_op->name() + "/axis");
    CreateIntTensorConst(gather_axis, {src_op.axis.value()}, {},
                         tensorflow_graph);
    *gather_op->add_input() = gather_axis;
  }

  (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32);
  (*gather_op->mutable_attr())["Taxis"].set_type(DT_INT32);
  const tensorflow::DataType params_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*gather_op->mutable_attr())["Tparams"].set_type(params_type);
}

void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op,
                           GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* argmax_op = tensorflow_graph->add_node();
  argmax_op->set_op("ArgMax");
  argmax_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *argmax_op->add_input() = src_op.inputs[0];
  *argmax_op->add_input() = src_op.inputs[1];
  (*argmax_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));
  (*argmax_op->mutable_attr())["Tidx"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[1]));
  (*argmax_op->mutable_attr())["output_type"].set_type(
      GetTensorFlowDataType(model, src_op.outputs[0]));
}

void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op,
                           GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node();
  argmin_op->set_op("ArgMin");
  argmin_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *argmin_op->add_input() = src_op.inputs[0];
  *argmin_op->add_input() = src_op.inputs[1];
  (*argmin_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));
  (*argmin_op->mutable_attr())["Tidx"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[1]));
  (*argmin_op->mutable_attr())["output_type"].set_type(
      GetTensorFlowDataType(model, src_op.outputs[0]));
}

void ConvertTransposeOperator(const Model& model,
                              const TransposeOperator& src_op,
                              GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node();
  transpose_op->set_op("Transpose");
  transpose_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *transpose_op->add_input() = src_op.inputs[0];
  *transpose_op->add_input() = src_op.inputs[1];
  (*transpose_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));
  (*transpose_op->mutable_attr())["Tperm"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[1]));
}

void ConvertTensorFlowShapeOperator(const Model& model,
                                    const TensorFlowShapeOperator& src_op,
                                    GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* shape_op = tensorflow_graph->add_node();
  shape_op->set_op("Shape");
  shape_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 1);
  *shape_op->add_input() = src_op.inputs[0];
  (*shape_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));
  (*shape_op->mutable_attr())["out_type"].set_type(
      GetTensorFlowDataType(model, src_op.outputs[0]));
}

void ConvertRankOperator(const Model& model,
                         const TensorFlowRankOperator& src_op,
                         GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* rank_op = tensorflow_graph->add_node();
  rank_op->set_op("Rank");
  rank_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 1);
  *rank_op->add_input() = src_op.inputs[0];
  (*rank_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));
}

void ConvertRangeOperator(const Model& model, const RangeOperator& src_op,
                          GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* range_op = tensorflow_graph->add_node();
  range_op->set_op("Range");
  range_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 3);
  *range_op->add_input() = src_op.inputs[0];
  *range_op->add_input() = src_op.inputs[1];
  *range_op->add_input() = src_op.inputs[2];
  (*range_op->mutable_attr())["Tidx"].set_type(
      GetTensorFlowDataTypeForOp(src_op.dtype, /*op_name=*/src_op.outputs[0]));
}

void ConvertPackOperator(const Model& model, const PackOperator& src_op,
                         GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* pack_op = tensorflow_graph->add_node();
  pack_op->set_op("Pack");
  pack_op->set_name(src_op.outputs[0]);
  for (const auto& input : src_op.inputs) {
    *pack_op->add_input() = input;
  }
  (*pack_op->mutable_attr())["axis"].set_i(src_op.axis);
  (*pack_op->mutable_attr())["N"].set_i(src_op.inputs.size());
  (*pack_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
}

void ConvertFillOperator(const Model& model, const FillOperator& src_op,
                         GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* fill_op = tensorflow_graph->add_node();
  fill_op->set_op("Fill");
  fill_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *fill_op->add_input() = src_op.inputs[0];
  *fill_op->add_input() = src_op.inputs[1];
  (*fill_op->mutable_attr())["index_type"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));
  (*fill_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[1]));
}

void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op,
                             GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* floor_div_op = tensorflow_graph->add_node();
  floor_div_op->set_op("FloorDiv");
  floor_div_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *floor_div_op->add_input() = src_op.inputs[0];
  *floor_div_op->add_input() = src_op.inputs[1];
  (*floor_div_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));
}

void ConvertFloorModOperator(const Model& model, const FloorModOperator& src_op,
                             GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* floor_mod_op = tensorflow_graph->add_node();
  floor_mod_op->set_op("FloorMod");
  floor_mod_op->set_name(src_op.outputs[0]);
  DCHECK_EQ(src_op.inputs.size(), 2);
  *floor_mod_op->add_input() = src_op.inputs[0];
  *floor_mod_op->add_input() = src_op.inputs[1];
  (*floor_mod_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));
}

void ConvertExpandDimsOperator(const Model& model,
                               const ExpandDimsOperator& src_op,
                               GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* expand_dims_op = tensorflow_graph->add_node();
  expand_dims_op->set_op("ExpandDims");
  expand_dims_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *expand_dims_op->add_input() = src_op.inputs[0];
  *expand_dims_op->add_input() = src_op.inputs[1];
  (*expand_dims_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));
  (*expand_dims_op->mutable_attr())["Tdim"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[1]));
}

void ConvertResizeBilinearOperator(const Model& model,
                                   const ResizeBilinearOperator& src_op,
                                   GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* resize_op = tensorflow_graph->add_node();
  resize_op->set_op("ResizeBilinear");
  resize_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *resize_op->add_input() = src_op.inputs[0];
  *resize_op->add_input() = src_op.inputs[1];
  (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT);
  (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
  (*resize_op->mutable_attr())["half_pixel_centers"].set_b(
      src_op.half_pixel_centers);
}

void ConvertResizeNearestNeighborOperator(
    const Model& model, const ResizeNearestNeighborOperator& src_op,
    GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* resize_op = tensorflow_graph->add_node();
  resize_op->set_op("ResizeNearestNeighbor");
  resize_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
1440 *resize_op->add_input() = src_op.inputs[0];
1441 *resize_op->add_input() = src_op.inputs[1];
1442 (*resize_op->mutable_attr())["T"].set_type(DT_FLOAT);
1443 (*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
1444 (*resize_op->mutable_attr())["half_pixel_centers"].set_b(
1445 src_op.half_pixel_centers);
1446}
1447
1448void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op,
1449 GraphDef* tensorflow_graph) {
1450 tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node();
1451 onehot_op->set_op("OneHot");
1452 onehot_op->set_name(src_op.outputs[0]);
1453 CHECK_EQ(src_op.inputs.size(), 4);
1454 for (const auto& input : src_op.inputs) {
1455 *onehot_op->add_input() = input;
1456 }
1457 (*onehot_op->mutable_attr())["T"].set_type(
1458 GetTensorFlowDataType(model, src_op.outputs[0]));
1459 (*onehot_op->mutable_attr())["axis"].set_i(src_op.axis);
1460}
1461
1462namespace {
1463// TODO(aselle): Remove when available in absl
1464absl::string_view FindLongestCommonPrefix(absl::string_view a,
1465 absl::string_view b) {
1466 if (a.empty() || b.empty()) return absl::string_view();
1467
1468 const char* pa = a.data();
1469 const char* pb = b.data();
1470 std::string::difference_type count = 0;
1471 const std::string::difference_type limit = std::min(a.size(), b.size());
1472 while (count < limit && *pa == *pb) {
1473 ++pa;
1474 ++pb;
1475 ++count;
1476 }
1477
1478 return absl::string_view(a.data(), count);
1479}
1480} // namespace
1481
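// Expands a fused LstmCell into the primitive TF ops of a basic LSTM cell:
// ConcatV2 -> MatMul -> BiasAdd -> Split into the four gate pre-activations,
// followed by the Sigmoid/Tanh/Mul/Add lattice that produces the new state
// and activation outputs.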
void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
                             GraphDef* tensorflow_graph) {
  // Find the base name
  const std::string base(
      FindLongestCommonPrefix(src_op.outputs[LstmCellOperator::STATE_OUTPUT],
                              src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]));

  // Concatenate inputs
  const std::string concat_output = base + "basic_lstm_cell/concat";
  // Op names have been chosen to match the tf.slim LSTM naming
  // as closely as possible.
  const int axis =
      model.GetArray(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
          .shape()
          .dimensions_count() -
      1;
  // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
  // works the same since the tensor has the same underlying data layout.
  const std::string axis_output = concat_output + "/axis";
  CreateDummyConcatDimTensorConst(axis_output, axis, tensorflow_graph);
  tensorflow::NodeDef* concat_op = tensorflow_graph->add_node();
  concat_op->set_op("ConcatV2");
  concat_op->set_name(concat_output);
  *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT];
  *concat_op->add_input() = src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT];
  *concat_op->add_input() = axis_output;
  (*concat_op->mutable_attr())["T"].set_type(DT_FLOAT);
  (*concat_op->mutable_attr())["Tidx"].set_type(DT_INT32);
  (*concat_op->mutable_attr())["N"].set_i(2);  // Number of inputs

  // Write weights
  const std::string weights_output = base + "weights";
  CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
  const std::string weights_name = WalkUpToConstantArray(
      model, src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
  const auto& weights_array = model.GetArray(weights_name);
  // The FullyConnected weights are expected to have already been collapsed
  // into a 2-D matrix at this point.
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 2);
  CHECK(weights_array.buffer);
  CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
  const float* weights_data =
      weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
  ConvertFloatTensorConst(weights_output, weights_shape, weights_data,
                          AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph);

  // Fully connected matrix multiply
  const std::string matmul_output = base + "MatMul";
  tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
  matmul_op->set_op("MatMul");
  matmul_op->set_name(matmul_output);
  *matmul_op->add_input() = concat_output;
  *matmul_op->add_input() = weights_output;
  (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
  (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
  (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Write biases
  const std::string biases_output = base + "biases";
  CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
  const std::string bias_name = WalkUpToConstantArray(
      model, src_op.inputs[LstmCellOperator::BIASES_INPUT]);
  const auto& bias_array = model.GetArray(bias_name);
  // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
  Shape bias_shape_1d = bias_array.shape();
  UnextendShape(&bias_shape_1d, 1);
  CHECK(bias_array.buffer);
  CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
  const float* bias_data =
      bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
  ConvertFloatTensorConst(biases_output, bias_shape_1d, bias_data,
                          AxesOrder::kOneAxis, AxesOrder::kOneAxis,
                          tensorflow_graph,
                          LegacyScalarPolicy::kDoCreateLegacyScalars);

  // Add biases
  const std::string biasadd_output = base + "BiasAdd";
  tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
  biasadd_op->set_op("BiasAdd");
  biasadd_op->set_name(biasadd_output);
  biasadd_op->add_input(matmul_output);
  biasadd_op->add_input(biases_output);
  (*biasadd_op->mutable_attr())["data_format"].set_s("NHWC");
  (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Split the BiasAdd output into the four gate pre-activations.
  const std::string split_dim_output = base + "split/split_dim";
  // The dimension is the same as the concatenation dimension
  CreateDummyConcatDimTensorConst(split_dim_output, axis, tensorflow_graph);
  const std::string split_output = base + "split";
  tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
  split_op->set_op("Split");
  split_op->set_name(split_output);
  *split_op->add_input() = split_dim_output;
  *split_op->add_input() = biasadd_output;
  (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
  (*split_op->mutable_attr())["num_split"].set_i(4);  // Split into four outputs

  // Activation functions and memory computations
  const std::string tanh_0_output = base + "Tanh";
  tensorflow::NodeDef* tanh_0_op = tensorflow_graph->add_node();
  tanh_0_op->set_op("Tanh");
  tanh_0_op->set_name(tanh_0_output);
  *tanh_0_op->add_input() = split_output + ":1";
  (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string sigmoid_1_output = base + "Sigmoid_1";
  tensorflow::NodeDef* logistic_1_op = tensorflow_graph->add_node();
  logistic_1_op->set_op("Sigmoid");
  logistic_1_op->set_name(sigmoid_1_output);
  *logistic_1_op->add_input() = split_output;
  (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string mul_1_output = base + "mul_1";
  tensorflow::NodeDef* mul_1_op = tensorflow_graph->add_node();
  mul_1_op->set_op("Mul");
  mul_1_op->set_name(mul_1_output);
  *mul_1_op->add_input() = sigmoid_1_output;
  *mul_1_op->add_input() = tanh_0_output;
  (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string sigmoid_0_output = base + "Sigmoid";
  tensorflow::NodeDef* logistic_2_op = tensorflow_graph->add_node();
  logistic_2_op->set_op("Sigmoid");
  logistic_2_op->set_name(sigmoid_0_output);
  *logistic_2_op->add_input() = split_output + ":2";
  (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string sigmoid_2_output = base + "Sigmoid_2";
  tensorflow::NodeDef* logistic_3_op = tensorflow_graph->add_node();
  logistic_3_op->set_op("Sigmoid");
  logistic_3_op->set_name(sigmoid_2_output);
  *logistic_3_op->add_input() = split_output + ":3";
  (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string mul_0_output = base + "mul";
  tensorflow::NodeDef* mul_0_op = tensorflow_graph->add_node();
  mul_0_op->set_op("Mul");
  mul_0_op->set_name(mul_0_output);
  *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT];
  *mul_0_op->add_input() = sigmoid_0_output;
  (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string add_1_output =
      src_op.outputs[LstmCellOperator::STATE_OUTPUT];
  tensorflow::NodeDef* add_1_op = tensorflow_graph->add_node();
  add_1_op->set_op("Add");
  add_1_op->set_name(add_1_output);
  *add_1_op->add_input() = mul_0_output;
  *add_1_op->add_input() = mul_1_output;
  (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string tanh_1_output = base + "Tanh_1";
  tensorflow::NodeDef* tanh_1_op = tensorflow_graph->add_node();
  tanh_1_op->set_op("Tanh");
  tanh_1_op->set_name(tanh_1_output);
  *tanh_1_op->add_input() = add_1_output;
  (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  const std::string mul_2_output =
      src_op.outputs[LstmCellOperator::ACTIV_OUTPUT];
  tensorflow::NodeDef* mul_2_op = tensorflow_graph->add_node();
  mul_2_op->set_op("Mul");
  mul_2_op->set_name(mul_2_output);
  *mul_2_op->add_input() = tanh_1_output;
  *mul_2_op->add_input() = sigmoid_2_output;
  (*mul_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
}

void ConvertSpaceToBatchNDOperator(const Model& model,
                                   const SpaceToBatchNDOperator& src_op,
                                   GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
  new_op->set_op("SpaceToBatchND");
  new_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 3);
  *new_op->add_input() = src_op.inputs[0];
  *new_op->add_input() = src_op.inputs[1];
  *new_op->add_input() = src_op.inputs[2];
  const tensorflow::DataType params_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*new_op->mutable_attr())["T"].set_type(params_type);
  (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
  (*new_op->mutable_attr())["Tpaddings"].set_type(DT_INT32);
}

void ConvertBatchToSpaceNDOperator(const Model& model,
                                   const BatchToSpaceNDOperator& src_op,
                                   GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
  new_op->set_op("BatchToSpaceND");
  new_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 3);
  *new_op->add_input() = src_op.inputs[0];
  *new_op->add_input() = src_op.inputs[1];
  *new_op->add_input() = src_op.inputs[2];
  const tensorflow::DataType params_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*new_op->mutable_attr())["T"].set_type(params_type);
  (*new_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32);
  (*new_op->mutable_attr())["Tcrops"].set_type(DT_INT32);
}

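// Converts a Pad operator. In addition to the Pad node itself, this emits the
// "paddings" input as an Nx2 DT_INT32 Const, one (left, right) pair per
// dimension, e.g. {{1, 1}, {2, 2}} to pad a 2-D tensor by 1 row and 2 columns
// on each side.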
void ConvertPadOperator(const Model& model, const PadOperator& src_op,
                        GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
  new_op->set_op("Pad");
  new_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *new_op->add_input() = src_op.inputs[0];
  *new_op->add_input() = src_op.inputs[1];

  const tensorflow::DataType params_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*new_op->mutable_attr())["T"].set_type(params_type);

  // Create the paddings tensor.
  tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
  params_op->set_op("Const");
  params_op->set_name(src_op.inputs[1]);
  (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
  auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_INT32);

  CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
  for (int i = 0; i < src_op.left_padding.size(); ++i) {
    tensor->add_int_val(src_op.left_padding[i]);
    tensor->add_int_val(src_op.right_padding[i]);
  }
  auto* shape = tensor->mutable_tensor_shape();
  shape->add_dim()->set_size(src_op.left_padding.size());
  shape->add_dim()->set_size(2);
}

void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op,
                          GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
  new_op->set_op("PadV2");
  new_op->set_name(src_op.outputs[0]);
  // PadV2 takes three inputs: the tensor, the paddings, and the scalar
  // constant value to pad with.
  CHECK_EQ(src_op.inputs.size(), 3);
  *new_op->add_input() = src_op.inputs[0];
  *new_op->add_input() = src_op.inputs[1];
  *new_op->add_input() = src_op.inputs[2];

  const tensorflow::DataType params_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*new_op->mutable_attr())["T"].set_type(params_type);

  // Create the paddings tensor.
  tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
  params_op->set_op("Const");
  params_op->set_name(src_op.inputs[1]);
  (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
  auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_INT32);

  CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size());
  for (int i = 0; i < src_op.left_padding.size(); ++i) {
    tensor->add_int_val(src_op.left_padding[i]);
    tensor->add_int_val(src_op.right_padding[i]);
  }
  auto* shape = tensor->mutable_tensor_shape();
  shape->add_dim()->set_size(src_op.left_padding.size());
  shape->add_dim()->set_size(2);
}

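// Emits a 1-D DT_INT32 Const node named |input_name| holding |values|. Used
// below to materialize the begin/size/stride inputs of the slice operators.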
void CreateSliceInput(const std::string& input_name,
                      const std::vector<int>& values,
                      GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
  params_op->set_op("Const");
  params_op->set_name(input_name);
  (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
  auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_INT32);

  for (int i = 0; i < values.size(); ++i) {
    tensor->add_int_val(values[i]);
  }
  auto* shape = tensor->mutable_tensor_shape();
  shape->add_dim()->set_size(values.size());
}

void ConvertStridedSliceOperator(const Model& model,
                                 const StridedSliceOperator& src_op,
                                 GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
  new_op->set_op("StridedSlice");
  new_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 4);
  *new_op->add_input() = src_op.inputs[0];
  *new_op->add_input() = src_op.inputs[1];
  *new_op->add_input() = src_op.inputs[2];
  *new_op->add_input() = src_op.inputs[3];

  const tensorflow::DataType params_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*new_op->mutable_attr())["T"].set_type(params_type);

  (*new_op->mutable_attr())["Index"].set_type(DT_INT32);
  (*new_op->mutable_attr())["begin_mask"].set_i(src_op.begin_mask);
  (*new_op->mutable_attr())["ellipsis_mask"].set_i(src_op.ellipsis_mask);
  (*new_op->mutable_attr())["end_mask"].set_i(src_op.end_mask);
  (*new_op->mutable_attr())["new_axis_mask"].set_i(src_op.new_axis_mask);
  (*new_op->mutable_attr())["shrink_axis_mask"].set_i(src_op.shrink_axis_mask);

  // Create tensors for start/stop indices and strides.
  CreateSliceInput(src_op.inputs[1], src_op.start_indices, tensorflow_graph);
  CreateSliceInput(src_op.inputs[2], src_op.stop_indices, tensorflow_graph);
  CreateSliceInput(src_op.inputs[3], src_op.strides, tensorflow_graph);
}

void ConvertSliceOperator(const Model& model, const SliceOperator& src_op,
                          GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
  new_op->set_op("Slice");
  new_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 3);
  *new_op->add_input() = src_op.inputs[0];
  *new_op->add_input() = src_op.inputs[1];
  *new_op->add_input() = src_op.inputs[2];

  const tensorflow::DataType params_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*new_op->mutable_attr())["T"].set_type(params_type);
  (*new_op->mutable_attr())["Index"].set_type(DT_INT32);

  // Create tensors for begin and size inputs.
  CreateSliceInput(src_op.inputs[1], src_op.begin, tensorflow_graph);
  CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph);
}

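// Shared conversion for the reduction operators (Mean, Sum, Prod, Min, Max,
// Any). Emits the reduction node itself plus a DT_INT32 Const holding the
// reduction axes.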
template <typename T>
void ConvertReduceOperator(const Model& model, const T& src_op,
                           GraphDef* tensorflow_graph,
                           const std::string& op_name) {
  tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
  new_op->set_op(op_name);
  new_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *new_op->add_input() = src_op.inputs[0];
  *new_op->add_input() = src_op.inputs[1];

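  // The TF "Any" reduction is defined only for DT_BOOL inputs and carries no
  // "T" attr, so it is skipped here.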
  if (src_op.type != OperatorType::kAny) {
    const tensorflow::DataType params_type =
        GetTensorFlowDataType(model, src_op.inputs[0]);
    (*new_op->mutable_attr())["T"].set_type(params_type);
  }
  const tensorflow::DataType indices_type =
      GetTensorFlowDataType(model, src_op.inputs[1]);
  (*new_op->mutable_attr())["Tidx"].set_type(indices_type);

  if (src_op.keep_dims) {
    (*new_op->mutable_attr())["keep_dims"].set_b(true);
  }

  // Create the reduction-axes tensor.
  tensorflow::NodeDef* params_op = tensorflow_graph->add_node();
  params_op->set_op("Const");
  params_op->set_name(src_op.inputs[1]);
  (*params_op->mutable_attr())["dtype"].set_type(DT_INT32);
  auto* tensor = (*params_op->mutable_attr())["value"].mutable_tensor();
  tensor->set_dtype(DT_INT32);

  for (int i = 0; i < src_op.axis.size(); ++i) {
    tensor->add_int_val(src_op.axis[i]);
  }
  auto* shape = tensor->mutable_tensor_shape();
  shape->add_dim()->set_size(src_op.axis.size());
}

void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op,
                            GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
  new_op->set_op("Squeeze");
  new_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 1);
  *new_op->add_input() = src_op.inputs[0];

  const tensorflow::DataType params_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*new_op->mutable_attr())["T"].set_type(params_type);

  if (!src_op.squeeze_dims.empty()) {
    auto& squeeze_dims = (*new_op->mutable_attr())["squeeze_dims"];
    for (int i : src_op.squeeze_dims) {
      squeeze_dims.mutable_list()->add_i(i);
    }
  }
}

void ConvertSubOperator(const Model& model, const SubOperator& src_op,
                        GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* sub_op = tensorflow_graph->add_node();
  sub_op->set_op("Sub");
  sub_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *sub_op->add_input() = src_op.inputs[0];
  *sub_op->add_input() = src_op.inputs[1];
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*sub_op->mutable_attr())["T"].set_type(data_type);
}

void ConvertTensorFlowMinimumOperator(const Model& model,
                                      const TensorFlowMinimumOperator& src_op,
                                      GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* min_op = tensorflow_graph->add_node();
  min_op->set_op("Minimum");
  min_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *min_op->add_input() = src_op.inputs[0];
  *min_op->add_input() = src_op.inputs[1];
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*min_op->mutable_attr())["T"].set_type(data_type);
}

void ConvertTensorFlowMaximumOperator(const Model& model,
                                      const TensorFlowMaximumOperator& src_op,
                                      GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* max_op = tensorflow_graph->add_node();
  max_op->set_op("Maximum");
  max_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *max_op->add_input() = src_op.inputs[0];
  *max_op->add_input() = src_op.inputs[1];
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*max_op->mutable_attr())["T"].set_type(data_type);
}

void ConvertSelectOperator(const Model& model, const SelectOperator& src_op,
                           GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* select_op = tensorflow_graph->add_node();
  select_op->set_op("Select");
  select_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 3);
  *select_op->add_input() = src_op.inputs[0];
  *select_op->add_input() = src_op.inputs[1];
  *select_op->add_input() = src_op.inputs[2];
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[1]);
  (*select_op->mutable_attr())["T"].set_type(data_type);
}

void ConvertTileOperator(const Model& model,
                         const TensorFlowTileOperator& src_op,
                         GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* tile_op = tensorflow_graph->add_node();
  tile_op->set_op("Tile");
  tile_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *tile_op->add_input() = src_op.inputs[0];
  *tile_op->add_input() = src_op.inputs[1];
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*tile_op->mutable_attr())["T"].set_type(data_type);
  const tensorflow::DataType multiples_data_type =
      GetTensorFlowDataType(model, src_op.inputs[1]);
  (*tile_op->mutable_attr())["Tmultiples"].set_type(multiples_data_type);
}

void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op,
                           GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* topk_op = tensorflow_graph->add_node();
  topk_op->set_op("TopKV2");
  topk_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *topk_op->add_input() = src_op.inputs[0];
  *topk_op->add_input() = src_op.inputs[1];
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*topk_op->mutable_attr())["T"].set_type(data_type);
  (*topk_op->mutable_attr())["sorted"].set_b(true);
}

void ConvertRandomUniformOperator(const Model& model,
                                  const RandomUniformOperator& src_op,
                                  GraphDef* tensorflow_graph) {
  CHECK(tensorflow_graph != nullptr);
  tensorflow::NodeDef* new_op = tensorflow_graph->add_node();
  new_op->set_op("RandomUniform");
  CHECK_EQ(src_op.inputs.size(), 1);
  new_op->set_name(src_op.outputs[0]);
  *new_op->add_input() = src_op.inputs[0];
  const tensorflow::DataType shape_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*new_op->mutable_attr())["T"].set_type(shape_type);
  (*new_op->mutable_attr())["dtype"].set_type(
      GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
  (*new_op->mutable_attr())["seed"].set_i(src_op.seed);
  (*new_op->mutable_attr())["seed2"].set_i(src_op.seed2);
}

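// Shared conversion for the binary comparison operators (Equal, NotEqual,
// Greater, GreaterEqual, Less, LessEqual); |op_name| selects the TF op.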
void ConvertComparisonOperator(const Model& model, const Operator& src_op,
                               const char* op_name,
                               GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* comparison_op = tensorflow_graph->add_node();
  comparison_op->set_op(op_name);
  comparison_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *comparison_op->add_input() = src_op.inputs[0];
  *comparison_op->add_input() = src_op.inputs[1];
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*comparison_op->mutable_attr())["T"].set_type(data_type);
}

void ConvertSparseToDenseOperator(const Model& model,
                                  const SparseToDenseOperator& src_op,
                                  const char* op_name,
                                  GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* sparse_to_dense_op = tensorflow_graph->add_node();
  sparse_to_dense_op->set_op(op_name);
  sparse_to_dense_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 4);
  for (int i = 0; i < 4; ++i) {
    *sparse_to_dense_op->add_input() = src_op.inputs[i];
  }
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[3]);
  (*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type);
  const tensorflow::DataType index_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type);
  (*sparse_to_dense_op->mutable_attr())["validate_indices"].set_b(
      src_op.validate_indices);
}

void ConvertPowOperator(const Model& model, const PowOperator& src_op,
                        const char* op_name, GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* pow_op = tensorflow_graph->add_node();
  pow_op->set_op(op_name);
  pow_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  for (int i = 0; i < 2; ++i) {
    *pow_op->add_input() = src_op.inputs[i];
  }
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*pow_op->mutable_attr())["T"].set_type(data_type);
}

void ConvertLogicalAndOperator(const Model& model,
                               const LogicalAndOperator& src_op,
                               GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
  logical_op->set_op("LogicalAnd");
  logical_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  for (int i = 0; i < 2; ++i) {
    *logical_op->add_input() = src_op.inputs[i];
  }
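  // LogicalAnd is defined only for DT_BOOL inputs, so there is no "T" attr
  // to set.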
}

void ConvertLogicalNotOperator(const Model& model,
                               const LogicalNotOperator& src_op,
                               GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* logical_op = tensorflow_graph->add_node();
  logical_op->set_op("LogicalNot");
  logical_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 1);
  *logical_op->add_input() = src_op.inputs[0];
}

void ConvertLogicalOrOperator(const Model& model,
                              const LogicalOrOperator& src_op,
                              const char* op_name, GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* logical_or_op = tensorflow_graph->add_node();
  logical_or_op->set_op(op_name);
  logical_or_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  for (int i = 0; i < 2; ++i) {
    *logical_or_op->add_input() = src_op.inputs[i];
  }
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*logical_or_op->mutable_attr())["T"].set_type(data_type);
}

void ConvertCTCBeamSearchDecoderOperator(
    const Model& model, const CTCBeamSearchDecoderOperator& src_op,
    const char* op_name, GraphDef* tensorflow_graph) {
  auto* op = tensorflow_graph->add_node();
  op->set_op(op_name);
  op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  for (int i = 0; i < 2; ++i) {
    *op->add_input() = src_op.inputs[i];
  }
  (*op->mutable_attr())["beam_width"].set_i(src_op.beam_width);
  (*op->mutable_attr())["top_paths"].set_i(src_op.top_paths);
  (*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated);
}

void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op,
                           const char* op_name, GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* unpack_op = tensorflow_graph->add_node();
  unpack_op->set_op(op_name);
  unpack_op->set_name(src_op.outputs[0]);
  // Unpack carries a single tensor input; num and axis are attributes.
  CHECK_EQ(src_op.inputs.size(), 1);
  *unpack_op->add_input() = src_op.inputs[0];
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*unpack_op->mutable_attr())["T"].set_type(data_type);
  (*unpack_op->mutable_attr())["num"].set_i(src_op.num);
  (*unpack_op->mutable_attr())["axis"].set_i(src_op.axis);
}

void ConvertZerosLikeOperator(const Model& model,
                              const TensorFlowZerosLikeOperator& src_op,
                              const char* op_name, GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* zeros_like_op = tensorflow_graph->add_node();
  zeros_like_op->set_op(op_name);
  zeros_like_op->set_name(src_op.outputs[0]);
  DCHECK_EQ(src_op.inputs.size(), 1);
  *zeros_like_op->add_input() = src_op.inputs[0];
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*zeros_like_op->mutable_attr())["T"].set_type(data_type);
}

void ConvertReverseV2Operator(const Model& model,
                              const ReverseV2Operator& src_op,
                              const char* op_name, GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* reverse_v2_op = tensorflow_graph->add_node();
  reverse_v2_op->set_op(op_name);
  reverse_v2_op->set_name(src_op.outputs[0]);
  DCHECK_EQ(src_op.inputs.size(), 2);
  *reverse_v2_op->add_input() = src_op.inputs[0];
  *reverse_v2_op->add_input() = src_op.inputs[1];
  const tensorflow::DataType data_type =
      GetTensorFlowDataType(model, src_op.inputs[0]);
  (*reverse_v2_op->mutable_attr())["T"].set_type(data_type);
}

void ConvertReverseSequenceOperator(const Model& model,
                                    const ReverseSequenceOperator& src_op,
                                    GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* reverse_seq_op = tensorflow_graph->add_node();
  reverse_seq_op->set_op("ReverseSequence");
  reverse_seq_op->set_name(src_op.outputs[0]);
  CHECK_EQ(src_op.inputs.size(), 2);
  *reverse_seq_op->add_input() = src_op.inputs[0];
  *reverse_seq_op->add_input() = src_op.inputs[1];
  // "T" and "Tlen" are required type attrs of TF's ReverseSequence op.
  (*reverse_seq_op->mutable_attr())["T"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[0]));
  (*reverse_seq_op->mutable_attr())["Tlen"].set_type(
      GetTensorFlowDataType(model, src_op.inputs[1]));
  (*reverse_seq_op->mutable_attr())["seq_dim"].set_i(src_op.seq_dim);
  (*reverse_seq_op->mutable_attr())["batch_dim"].set_i(src_op.batch_dim);
}

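// Dispatches a single toco operator to the matching conversion routine above.
// Fused activation functions cannot be represented in a plain TF GraphDef, so
// models containing them must have been unfused by earlier transformations.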
void ConvertOperator(const Model& model, const Operator& src_op,
                     GraphDef* tensorflow_graph) {
  if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
    LOG(FATAL)
        << "Unsupported: the input model has a fused activation function";
  }

  if (src_op.type == OperatorType::kConv) {
    ConvertConvOperator(model, static_cast<const ConvOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kDepthwiseConv) {
    ConvertDepthwiseConvOperator(
        model, static_cast<const DepthwiseConvOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kDepthToSpace) {
    ConvertDepthToSpaceOperator(
        model, static_cast<const DepthToSpaceOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kSpaceToDepth) {
    ConvertSpaceToDepthOperator(
        model, static_cast<const SpaceToDepthOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kFullyConnected) {
    ConvertFullyConnectedOperator(
        model, static_cast<const FullyConnectedOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kAdd) {
    ConvertAddOperator(model, static_cast<const AddOperator&>(src_op),
                       tensorflow_graph);
  } else if (src_op.type == OperatorType::kAddN) {
    ConvertAddNOperator(model, static_cast<const AddNOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kMul) {
    ConvertMulOperator(model, static_cast<const MulOperator&>(src_op),
                       tensorflow_graph);
  } else if (src_op.type == OperatorType::kDiv) {
    ConvertDivOperator(model, static_cast<const DivOperator&>(src_op),
                       tensorflow_graph);
  } else if (src_op.type == OperatorType::kRelu) {
    ConvertReluOperator(model, static_cast<const ReluOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kRelu1) {
    ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kRelu6) {
    ConvertRelu6Operator(static_cast<const Relu6Operator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kLog) {
    ConvertLogOperator(static_cast<const LogOperator&>(src_op),
                       tensorflow_graph);
  } else if (src_op.type == OperatorType::kLogistic) {
    ConvertLogisticOperator(static_cast<const LogisticOperator&>(src_op),
                            tensorflow_graph);
  } else if (src_op.type == OperatorType::kTanh) {
    ConvertTanhOperator(static_cast<const TanhOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kL2Normalization) {
    ConvertL2NormalizationOperator(
        static_cast<const L2NormalizationOperator&>(src_op), tensorflow_graph);
  } else if (src_op.type == OperatorType::kSoftmax) {
    ConvertSoftmaxOperator(model, static_cast<const SoftmaxOperator&>(src_op),
                           tensorflow_graph);
  } else if (src_op.type == OperatorType::kLogSoftmax) {
    ConvertLogSoftmaxOperator(model,
                              static_cast<const LogSoftmaxOperator&>(src_op),
                              tensorflow_graph);
  } else if (src_op.type == OperatorType::kLocalResponseNormalization) {
    ConvertLocalResponseNormalizationOperator(
        static_cast<const LocalResponseNormalizationOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kLstmCell) {
    ConvertLstmCellOperator(model,
                            static_cast<const LstmCellOperator&>(src_op),
                            tensorflow_graph);
  } else if (src_op.type == OperatorType::kMaxPool) {
    ConvertMaxPoolOperator(static_cast<const MaxPoolOperator&>(src_op),
                           tensorflow_graph);
  } else if (src_op.type == OperatorType::kAveragePool) {
    ConvertAveragePoolOperator(static_cast<const AveragePoolOperator&>(src_op),
                               tensorflow_graph);
  } else if (src_op.type == OperatorType::kConcatenation) {
    ConvertConcatenationOperator(
        model, static_cast<const ConcatenationOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kReshape) {
    ConvertTensorFlowReshapeOperator(
        model, static_cast<const TensorFlowReshapeOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kL2Pool) {
    ConvertL2PoolOperator(static_cast<const L2PoolOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kSquare) {
    ConvertSquareOperator(static_cast<const TensorFlowSquareOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kSqrt) {
    ConvertSqrtOperator(static_cast<const TensorFlowSqrtOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kRsqrt) {
    ConvertRsqrtOperator(model,
                         static_cast<const TensorFlowRsqrtOperator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kSplit) {
    ConvertSplitOperator(model,
                         static_cast<const TensorFlowSplitOperator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kSplitV) {
    ConvertSplitVOperator(model,
                          static_cast<const TensorFlowSplitVOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kFakeQuant) {
    ConvertFakeQuantOperator(static_cast<const FakeQuantOperator&>(src_op),
                             tensorflow_graph);
  } else if (src_op.type == OperatorType::kCast) {
    ConvertCastOperator(model, static_cast<const CastOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kFloor) {
    ConvertFloorOperator(model, static_cast<const FloorOperator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kCeil) {
    ConvertCeilOperator(model, static_cast<const CeilOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kRound) {
    ConvertRoundOperator(model, static_cast<const RoundOperator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kGather) {
    ConvertGatherOperator(model, static_cast<const GatherOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kResizeBilinear) {
    ConvertResizeBilinearOperator(
        model, static_cast<const ResizeBilinearOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kResizeNearestNeighbor) {
    ConvertResizeNearestNeighborOperator(
        model, static_cast<const ResizeNearestNeighborOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kSpaceToBatchND) {
    ConvertSpaceToBatchNDOperator(
        model, static_cast<const SpaceToBatchNDOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kBatchToSpaceND) {
    ConvertBatchToSpaceNDOperator(
        model, static_cast<const BatchToSpaceNDOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kPad) {
    ConvertPadOperator(model, static_cast<const PadOperator&>(src_op),
                       tensorflow_graph);
  } else if (src_op.type == OperatorType::kPadV2) {
    ConvertPadV2Operator(model, static_cast<const PadV2Operator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kStridedSlice) {
    ConvertStridedSliceOperator(
        model, static_cast<const StridedSliceOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kMean) {
    ConvertReduceOperator(model, static_cast<const MeanOperator&>(src_op),
                          tensorflow_graph, "Mean");
  } else if (src_op.type == OperatorType::kSum) {
    ConvertReduceOperator(model,
                          static_cast<const TensorFlowSumOperator&>(src_op),
                          tensorflow_graph, "Sum");
  } else if (src_op.type == OperatorType::kReduceProd) {
    ConvertReduceOperator(model,
                          static_cast<const TensorFlowProdOperator&>(src_op),
                          tensorflow_graph, "Prod");
  } else if (src_op.type == OperatorType::kReduceMin) {
    ConvertReduceOperator(model,
                          static_cast<const TensorFlowMinOperator&>(src_op),
                          tensorflow_graph, "Min");
  } else if (src_op.type == OperatorType::kReduceMax) {
    ConvertReduceOperator(model,
                          static_cast<const TensorFlowMaxOperator&>(src_op),
                          tensorflow_graph, "Max");
  } else if (src_op.type == OperatorType::kSub) {
    ConvertSubOperator(model, static_cast<const SubOperator&>(src_op),
                       tensorflow_graph);
  } else if (src_op.type == OperatorType::kMinimum) {
    ConvertTensorFlowMinimumOperator(
        model, static_cast<const TensorFlowMinimumOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kMaximum) {
    ConvertTensorFlowMaximumOperator(
        model, static_cast<const TensorFlowMaximumOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kSqueeze) {
    ConvertSqueezeOperator(model, static_cast<const SqueezeOperator&>(src_op),
                           tensorflow_graph);
  } else if (src_op.type == OperatorType::kSlice) {
    ConvertSliceOperator(model, static_cast<const SliceOperator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kArgMax) {
    ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kArgMin) {
    ConvertArgMinOperator(model, static_cast<const ArgMinOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kTopK_V2) {
    ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kTranspose) {
    ConvertTransposeOperator(
        model, static_cast<const TransposeOperator&>(src_op), tensorflow_graph);
  } else if (src_op.type == OperatorType::kShape) {
    ConvertTensorFlowShapeOperator(
        model, static_cast<const TensorFlowShapeOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kRank) {
    ConvertRankOperator(model,
                        static_cast<const TensorFlowRankOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kRange) {
    ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op),
                         tensorflow_graph);
  } else if (src_op.type == OperatorType::kPack) {
    ConvertPackOperator(model, static_cast<const PackOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kFill) {
    ConvertFillOperator(model, static_cast<const FillOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kFloorDiv) {
    ConvertFloorDivOperator(model,
                            static_cast<const FloorDivOperator&>(src_op),
                            tensorflow_graph);
  } else if (src_op.type == OperatorType::kFloorMod) {
    ConvertFloorModOperator(model,
                            static_cast<const FloorModOperator&>(src_op),
                            tensorflow_graph);
  } else if (src_op.type == OperatorType::kExpandDims) {
    ConvertExpandDimsOperator(model,
                              static_cast<const ExpandDimsOperator&>(src_op),
                              tensorflow_graph);
  } else if (src_op.type == OperatorType::kTransposeConv) {
    ConvertTransposeConvOperator(
        model, static_cast<const TransposeConvOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kRandomUniform) {
    ConvertRandomUniformOperator(
        model, static_cast<const RandomUniformOperator&>(src_op),
        tensorflow_graph);
  } else if (src_op.type == OperatorType::kEqual) {
    ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph);
  } else if (src_op.type == OperatorType::kNotEqual) {
    ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph);
  } else if (src_op.type == OperatorType::kGreater) {
    ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph);
  } else if (src_op.type == OperatorType::kGreaterEqual) {
    ConvertComparisonOperator(model, src_op, "GreaterEqual", tensorflow_graph);
  } else if (src_op.type == OperatorType::kLess) {
    ConvertComparisonOperator(model, src_op, "Less", tensorflow_graph);
  } else if (src_op.type == OperatorType::kLessEqual) {
    ConvertComparisonOperator(model, src_op, "LessEqual", tensorflow_graph);
  } else if (src_op.type == OperatorType::kSelect) {
    ConvertSelectOperator(model, static_cast<const SelectOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kTile) {
    ConvertTileOperator(model,
                        static_cast<const TensorFlowTileOperator&>(src_op),
                        tensorflow_graph);
  } else if (src_op.type == OperatorType::kPow) {
    ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow",
                       tensorflow_graph);
  } else if (src_op.type == OperatorType::kAny) {
    ConvertReduceOperator(model,
                          static_cast<const TensorFlowAnyOperator&>(src_op),
                          tensorflow_graph, "Any");
  } else if (src_op.type == OperatorType::kLogicalAnd) {
    ConvertLogicalAndOperator(model,
                              static_cast<const LogicalAndOperator&>(src_op),
                              tensorflow_graph);
  } else if (src_op.type == OperatorType::kLogicalNot) {
    ConvertLogicalNotOperator(model,
                              static_cast<const LogicalNotOperator&>(src_op),
                              tensorflow_graph);
  } else if (src_op.type == OperatorType::kOneHot) {
    ConvertOneHotOperator(model, static_cast<const OneHotOperator&>(src_op),
                          tensorflow_graph);
  } else if (src_op.type == OperatorType::kLogicalOr) {
    ConvertLogicalOrOperator(model,
                             static_cast<const LogicalOrOperator&>(src_op),
                             "LogicalOr", tensorflow_graph);
  } else if (src_op.type == OperatorType::kCTCBeamSearchDecoder) {
    ConvertCTCBeamSearchDecoderOperator(
        model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op),
        "CTCBeamSearchDecoder", tensorflow_graph);
  } else if (src_op.type == OperatorType::kUnpack) {
    ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op),
                          "Unpack", tensorflow_graph);
  } else if (src_op.type == OperatorType::kZerosLike) {
    ConvertZerosLikeOperator(
        model, static_cast<const TensorFlowZerosLikeOperator&>(src_op),
        "ZerosLike", tensorflow_graph);
  } else if (src_op.type == OperatorType::kReverseV2) {
    ConvertReverseV2Operator(model,
                             static_cast<const ReverseV2Operator&>(src_op),
                             "ReverseV2", tensorflow_graph);
  } else if (src_op.type == OperatorType::kReverseSequence) {
    ConvertReverseSequenceOperator(
        model, static_cast<const ReverseSequenceOperator&>(src_op),
        tensorflow_graph);
  } else {
    LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
  }
}

void AddPlaceholder(const std::string& name, ArrayDataType type,
                    GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* placeholder = tensorflow_graph->add_node();
  placeholder->set_op("Placeholder");
  switch (type) {
    case ArrayDataType::kBool:
      (*placeholder->mutable_attr())["dtype"].set_type(DT_BOOL);
      break;
    case ArrayDataType::kFloat:
      (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);
      break;
    case ArrayDataType::kUint8:
      (*placeholder->mutable_attr())["dtype"].set_type(DT_UINT8);
      break;
    case ArrayDataType::kInt32:
      (*placeholder->mutable_attr())["dtype"].set_type(DT_INT32);
      break;
    case ArrayDataType::kUint32:
      (*placeholder->mutable_attr())["dtype"].set_type(DT_UINT32);
      break;
    case ArrayDataType::kInt64:
      (*placeholder->mutable_attr())["dtype"].set_type(DT_INT64);
      break;
    case ArrayDataType::kInt16:
      (*placeholder->mutable_attr())["dtype"].set_type(DT_INT16);
      break;
    case ArrayDataType::kComplex64:
      (*placeholder->mutable_attr())["dtype"].set_type(DT_COMPLEX64);
      break;
    default:
      LOG(FATAL) << "Unexpected data type in array \"" << name << "\"";
  }
  placeholder->set_name(name);
}

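// Adds a DT_FLOAT Placeholder for an RNN state array. If the array has a
// known shape it is used as-is; otherwise the state is assumed to be a
// [1, size] row vector.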
void AddPlaceholderForRNNState(const Model& model, const std::string& name,
                               int size, GraphDef* tensorflow_graph) {
  tensorflow::NodeDef* placeholder = tensorflow_graph->add_node();
  placeholder->set_op("Placeholder");
  placeholder->set_name(name);
  (*placeholder->mutable_attr())["dtype"].set_type(DT_FLOAT);

  auto* shape = (*placeholder->mutable_attr())["shape"].mutable_shape();
  const auto& state_array = model.GetArray(name);
  if (state_array.has_shape()) {
    const auto& state_shape = state_array.shape();
    const int kDims = state_shape.dimensions_count();
    for (int i = 0; i < kDims; ++i) {
      shape->add_dim()->set_size(state_shape.dims(i));
    }
  } else {
    shape->add_dim()->set_size(1);
    shape->add_dim()->set_size(size);
  }
}

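// Builds the complete GraphDef: Placeholders for the model inputs and RNN
// states first, then one or more nodes per operator, and finally Const nodes
// for any constant arrays not already emitted by an operator conversion.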
void ExportTensorFlowGraphDefImplementation(const Model& model,
                                            GraphDef* tensorflow_graph) {
  for (const auto& input_array : model.flags.input_arrays()) {
    AddPlaceholder(input_array.name(),
                   model.GetArray(input_array.name()).data_type,
                   tensorflow_graph);
  }
  for (const auto& rnn_state : model.flags.rnn_states()) {
    AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(),
                              tensorflow_graph);
  }
  for (const auto& op : model.operators) {
    ConvertOperator(model, *op, tensorflow_graph);
  }
  // Generically export constant arrays that haven't already been exported by
  // the operator conversions above. It's important that this comes after, as
  // some operators need to export the arrays they reference in a specific
  // way, rather than in the generic way done below.
  for (const auto& array_pair : model.GetArrayMap()) {
    const std::string& array_name = array_pair.first;
    const auto& array = *array_pair.second;
    if (array.buffer) {
      switch (array.data_type) {
        case ArrayDataType::kBool:
          ConvertBoolTensorConst(model, array_name, tensorflow_graph);
          break;
        case ArrayDataType::kFloat:
          ConvertFloatTensorConst(model, array_name, tensorflow_graph);
          break;
        case ArrayDataType::kInt32:
          ConvertIntTensorConst(model, array_name, tensorflow_graph);
          break;
        case ArrayDataType::kComplex64:
          ConvertComplex64TensorConst(model, array_name, tensorflow_graph);
          break;
        default:
          break;
      }
    }
  }
}
}  // namespace

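// GraphDef has no native way to attach min/max metadata to a Const node, so
// each constant array carrying minmax information is moved to a new
// "<name>/data" array and re-exposed under its original name through a
// FakeQuant operator that records the range, i.e.
// Const("foo/data") -> FakeQuant -> "foo".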
void EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model* model) {
  for (const auto& array_kv : model->GetArrayMap()) {
    const std::string& array_name = array_kv.first;
    Array& array = *array_kv.second;
    if (!array.buffer || !array.minmax) {
      continue;
    }
    const std::string wrapped_array_name =
        AvailableArrayName(*model, array_name + "/data");
    Array& wrapped_array = model->GetOrCreateArray(wrapped_array_name);
    wrapped_array.data_type = array.data_type;
    wrapped_array.copy_shape(array.shape());
    wrapped_array.buffer = std::move(array.buffer);
    FakeQuantOperator* fakequant_op = new FakeQuantOperator;
    fakequant_op->inputs = {wrapped_array_name};
    fakequant_op->outputs = {array_name};
    fakequant_op->minmax = std::make_unique<MinMax>();
    *fakequant_op->minmax = *array.minmax;
    const auto it = FindOpWithInput(*model, array_name);
    model->operators.emplace(it, fakequant_op);
  }
  CheckInvariants(*model);
}

void ExportTensorFlowGraphDef(const Model& model,
                              std::string* output_file_contents) {
  CHECK(output_file_contents->empty());
  GraphDef tensorflow_graph;
  ExportTensorFlowGraphDefImplementation(model, &tensorflow_graph);
  LogDumpGraphDef(kLogLevelModelChanged, "AT EXPORT", tensorflow_graph);
  CHECK(tensorflow_graph.SerializeToString(output_file_contents));
}
}  // namespace toco
