1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <algorithm> |
16 | #include <memory> |
17 | #include <string> |
18 | #include <unordered_map> |
19 | #include <utility> |
20 | #include <vector> |
21 | |
22 | #include "google/protobuf/map.h" |
23 | #include "google/protobuf/text_format.h" |
24 | #include "absl/memory/memory.h" |
25 | #include "absl/strings/string_view.h" |
26 | #include "tensorflow/core/framework/attr_value.pb.h" |
27 | #include "tensorflow/core/framework/graph.pb.h" |
28 | #include "tensorflow/core/framework/node_def.pb.h" |
29 | #include "tensorflow/core/framework/tensor.pb.h" |
30 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
31 | #include "tensorflow/core/framework/types.pb.h" |
32 | #include "tensorflow/core/platform/logging.h" |
33 | #include "tensorflow/lite/toco/model.h" |
34 | #include "tensorflow/lite/toco/model_flags.pb.h" |
35 | #include "tensorflow/lite/toco/runtime/types.h" |
36 | #include "tensorflow/lite/toco/tensorflow_util.h" |
37 | #include "tensorflow/lite/toco/tooling_util.h" |
38 | |
39 | using tensorflow::DT_BOOL; |
40 | using tensorflow::DT_COMPLEX64; |
41 | using tensorflow::DT_FLOAT; |
42 | using tensorflow::DT_INT16; |
43 | using tensorflow::DT_INT32; |
44 | using tensorflow::DT_INT64; |
45 | using tensorflow::DT_UINT32; |
46 | using tensorflow::DT_UINT8; |
47 | using tensorflow::GraphDef; |
48 | using tensorflow::TensorProto; |
49 | |
50 | namespace toco { |
51 | namespace { |
52 | |
53 | tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type, |
54 | const std::string& error_location) { |
55 | switch (data_type) { |
56 | case ArrayDataType::kBool: |
57 | return tensorflow::DT_BOOL; |
58 | case ArrayDataType::kFloat: |
59 | return tensorflow::DT_FLOAT; |
60 | case ArrayDataType::kUint8: |
61 | return tensorflow::DT_UINT8; |
62 | case ArrayDataType::kInt16: |
63 | return tensorflow::DT_INT16; |
64 | case ArrayDataType::kUint16: |
65 | return tensorflow::DT_UINT16; |
66 | case ArrayDataType::kInt32: |
67 | return tensorflow::DT_INT32; |
68 | case ArrayDataType::kUint32: |
69 | return tensorflow::DT_UINT32; |
70 | case ArrayDataType::kInt64: |
71 | return tensorflow::DT_INT64; |
72 | case ArrayDataType::kString: |
73 | return tensorflow::DT_STRING; |
74 | case ArrayDataType::kComplex64: |
75 | return tensorflow::DT_COMPLEX64; |
76 | default: |
77 | case ArrayDataType::kNone: |
78 | LOG(FATAL) << "Unsupported data type '" << ArrayDataTypeName(data_type) |
79 | << "' in " << error_location; |
80 | return tensorflow::DT_INVALID; |
81 | } |
82 | } |
83 | |
84 | tensorflow::DataType GetTensorFlowDataTypeForOp(ArrayDataType data_type, |
85 | const std::string& op_name) { |
86 | return GetTensorFlowDataType(data_type, "op '" + op_name + "'" ); |
87 | } |
88 | |
89 | tensorflow::DataType GetTensorFlowDataType(const Model& model, |
90 | const std::string& array_name) { |
91 | return GetTensorFlowDataType(model.GetArray(array_name).data_type, |
92 | "array '" + array_name + "'" ); |
93 | } |
94 | |
// Controls how a 1-D shape holding exactly one element is exported.
//
// TensorFlow sometimes rejects what it calls "legacy scalars": 1-D shapes
// whose single dimension has size 1 (see OpKernel::IsLegacyScalar and
// OpKernel::allow_legacy_scalars). Exporters therefore normally replace such
// a shape by a 0-D shape (kAvoidLegacyScalars).
//
// Bias vectors consumed by BiasAdd are the exception: TensorFlow requires
// them to be 1-D, so a bias of depth 1 must keep its 1-D shape even though
// that makes it a legacy scalar (kDoCreateLegacyScalars).
enum class LegacyScalarPolicy {
  kAvoidLegacyScalars,
  kDoCreateLegacyScalars,
};
108 | |
109 | void ExportFloatArray(const Shape& input_shape, const float* input_data, |
110 | TensorProto* output_tensor, |
111 | LegacyScalarPolicy legacy_scalar_policy) { |
112 | output_tensor->set_dtype(DT_FLOAT); |
113 | const int input_flat_size = RequiredBufferSizeForShape(input_shape); |
114 | auto* shape = output_tensor->mutable_tensor_shape(); |
115 | |
116 | const int kDims = input_shape.dimensions_count(); |
117 | if (legacy_scalar_policy == LegacyScalarPolicy::kDoCreateLegacyScalars || |
118 | kDims > 1 || (kDims == 1 && input_shape.dims(0) > 1)) { |
119 | for (int i = 0; i < kDims; ++i) { |
120 | shape->add_dim()->set_size(input_shape.dims(i)); |
121 | } |
122 | } |
123 | output_tensor->set_tensor_content( |
124 | std::string(reinterpret_cast<const char*>(input_data), |
125 | sizeof(*input_data) * input_flat_size)); |
126 | } |
127 | |
128 | void ExportFloatArray(AxesOrder input_axes_order, const Shape& input_shape, |
129 | const float* input_data, AxesOrder output_axes_order, |
130 | TensorProto* output_tensor, |
131 | LegacyScalarPolicy legacy_scalar_policy) { |
132 | CHECK_EQ(AxesCount(output_axes_order), AxesCount(input_axes_order)); |
133 | output_tensor->set_dtype(DT_FLOAT); |
134 | CHECK_EQ(input_shape.dimensions_count(), AxesCount(input_axes_order)); |
135 | const int input_flat_size = RequiredBufferSizeForShape(input_shape); |
136 | |
137 | Shape shuffled_shape; |
138 | ShuffleDims(input_shape, input_axes_order, output_axes_order, |
139 | &shuffled_shape); |
140 | std::vector<float> shuffled_data(input_flat_size); |
141 | ShuffleArray(input_shape, input_axes_order, output_axes_order, shuffled_shape, |
142 | input_data, shuffled_data.data()); |
143 | |
144 | ExportFloatArray(shuffled_shape, shuffled_data.data(), output_tensor, |
145 | legacy_scalar_policy); |
146 | } |
147 | |
148 | bool HasAlreadyExportedConst(const std::string& name, |
149 | const GraphDef& tensorflow_graph) { |
150 | for (const auto& node : tensorflow_graph.node()) { |
151 | if (node.op() == "Const" && node.name() == name) { |
152 | return true; |
153 | } |
154 | } |
155 | return false; |
156 | } |
157 | |
158 | void ConvertFloatTensorConst(const std::string& name, const Shape& input_shape, |
159 | const float* input_data, |
160 | AxesOrder input_axes_order, |
161 | AxesOrder output_axes_order, |
162 | GraphDef* tensorflow_graph, |
163 | LegacyScalarPolicy legacy_scalar_policy) { |
164 | if (HasAlreadyExportedConst(name, *tensorflow_graph)) { |
165 | return; |
166 | } |
167 | tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); |
168 | const_op->set_op("Const" ); |
169 | const_op->set_name(name); |
170 | (*const_op->mutable_attr())["dtype" ].set_type(DT_FLOAT); |
171 | auto* tensor = (*const_op->mutable_attr())["value" ].mutable_tensor(); |
172 | ExportFloatArray(input_axes_order, input_shape, input_data, output_axes_order, |
173 | tensor, legacy_scalar_policy); |
174 | } |
175 | |
176 | void ConvertFloatTensorConst(const std::string& name, const Shape& input_shape, |
177 | const float* input_data, |
178 | AxesOrder input_axes_order, |
179 | AxesOrder output_axes_order, |
180 | GraphDef* tensorflow_graph) { |
181 | ConvertFloatTensorConst(name, input_shape, input_data, input_axes_order, |
182 | output_axes_order, tensorflow_graph, |
183 | LegacyScalarPolicy::kAvoidLegacyScalars); |
184 | } |
185 | |
186 | void ConvertFloatTensorConst(const Model& model, const std::string& name, |
187 | AxesOrder input_axes_order, |
188 | AxesOrder output_axes_order, |
189 | GraphDef* tensorflow_graph) { |
190 | if (HasAlreadyExportedConst(name, *tensorflow_graph)) { |
191 | return; |
192 | } |
193 | CHECK(model.HasArray(name)); |
194 | const auto& input_array = model.GetArray(name); |
195 | const auto& input_shape = input_array.shape(); |
196 | CHECK(input_array.buffer); |
197 | CHECK(input_array.buffer->type == ArrayDataType::kFloat); |
198 | const float* input_data = |
199 | input_array.GetBuffer<ArrayDataType::kFloat>().data.data(); |
200 | ConvertFloatTensorConst(name, input_shape, input_data, input_axes_order, |
201 | output_axes_order, tensorflow_graph); |
202 | } |
203 | |
204 | void ConvertFloatTensorConst(const Model& model, const std::string& name, |
205 | GraphDef* tensorflow_graph) { |
206 | if (HasAlreadyExportedConst(name, *tensorflow_graph)) { |
207 | return; |
208 | } |
209 | tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); |
210 | const_op->set_op("Const" ); |
211 | const_op->set_name(name); |
212 | (*const_op->mutable_attr())["dtype" ].set_type(DT_FLOAT); |
213 | auto* tensor = (*const_op->mutable_attr())["value" ].mutable_tensor(); |
214 | CHECK(model.HasArray(name)); |
215 | const auto& input_array = model.GetArray(name); |
216 | const auto& input_shape = input_array.shape(); |
217 | CHECK(input_array.buffer); |
218 | CHECK(input_array.buffer->type == ArrayDataType::kFloat); |
219 | const float* input_data = |
220 | input_array.GetBuffer<ArrayDataType::kFloat>().data.data(); |
221 | ExportFloatArray(input_shape, input_data, tensor, |
222 | LegacyScalarPolicy::kAvoidLegacyScalars); |
223 | } |
224 | |
225 | void ConvertBoolTensorConst(const Model& model, const std::string& name, |
226 | GraphDef* tensorflow_graph) { |
227 | if (HasAlreadyExportedConst(name, *tensorflow_graph)) { |
228 | return; |
229 | } |
230 | CHECK(model.HasArray(name)); |
231 | const auto& array = model.GetArray(name); |
232 | tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); |
233 | const_op->set_op("Const" ); |
234 | const_op->set_name(name); |
235 | (*const_op->mutable_attr())["dtype" ].set_type(DT_BOOL); |
236 | auto* tensor = (*const_op->mutable_attr())["value" ].mutable_tensor(); |
237 | tensor->set_dtype(DT_BOOL); |
238 | const auto& data = array.GetBuffer<ArrayDataType::kBool>().data; |
239 | for (auto index : data) { |
240 | tensor->add_bool_val(index); |
241 | } |
242 | const auto& array_shape = array.shape(); |
243 | auto* shape = tensor->mutable_tensor_shape(); |
244 | for (int i = 0; i < array_shape.dimensions_count(); i++) { |
245 | shape->add_dim()->set_size(array_shape.dims(i)); |
246 | } |
247 | } |
248 | |
249 | void ConvertIntTensorConst(const Model& model, const std::string& name, |
250 | GraphDef* tensorflow_graph) { |
251 | if (HasAlreadyExportedConst(name, *tensorflow_graph)) { |
252 | return; |
253 | } |
254 | CHECK(model.HasArray(name)); |
255 | const auto& array = model.GetArray(name); |
256 | tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); |
257 | const_op->set_op("Const" ); |
258 | const_op->set_name(name); |
259 | (*const_op->mutable_attr())["dtype" ].set_type(DT_INT32); |
260 | auto* tensor = (*const_op->mutable_attr())["value" ].mutable_tensor(); |
261 | tensor->set_dtype(DT_INT32); |
262 | const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data; |
263 | for (auto index : data) { |
264 | tensor->add_int_val(index); |
265 | } |
266 | const auto& array_shape = array.shape(); |
267 | auto* shape = tensor->mutable_tensor_shape(); |
268 | for (int i = 0; i < array_shape.dimensions_count(); i++) { |
269 | shape->add_dim()->set_size(array_shape.dims(i)); |
270 | } |
271 | } |
272 | |
273 | void CreateIntTensorConst(const std::string& name, |
274 | const std::vector<int32>& data, |
275 | const std::vector<int32>& shape, |
276 | GraphDef* tensorflow_graph) { |
277 | if (HasAlreadyExportedConst(name, *tensorflow_graph)) { |
278 | return; |
279 | } |
280 | tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); |
281 | const_op->set_op("Const" ); |
282 | const_op->set_name(name); |
283 | (*const_op->mutable_attr())["dtype" ].set_type(DT_INT32); |
284 | auto* tensor = (*const_op->mutable_attr())["value" ].mutable_tensor(); |
285 | tensor->set_dtype(DT_INT32); |
286 | for (auto index : data) { |
287 | tensor->add_int_val(index); |
288 | } |
289 | auto* tensor_shape = tensor->mutable_tensor_shape(); |
290 | int num_elements = 1; |
291 | for (int size : shape) { |
292 | tensor_shape->add_dim()->set_size(size); |
293 | num_elements *= size; |
294 | } |
295 | CHECK_EQ(num_elements, data.size()); |
296 | } |
297 | |
298 | void ConvertComplex64TensorConst(const Model& model, const std::string& name, |
299 | GraphDef* tensorflow_graph) { |
300 | if (HasAlreadyExportedConst(name, *tensorflow_graph)) { |
301 | return; |
302 | } |
303 | CHECK(model.HasArray(name)); |
304 | const auto& array = model.GetArray(name); |
305 | tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); |
306 | const_op->set_op("Const" ); |
307 | const_op->set_name(name); |
308 | (*const_op->mutable_attr())["dtype" ].set_type(DT_COMPLEX64); |
309 | auto* tensor = (*const_op->mutable_attr())["value" ].mutable_tensor(); |
310 | tensor->set_dtype(DT_COMPLEX64); |
311 | const auto& data = array.GetBuffer<ArrayDataType::kComplex64>().data; |
312 | for (auto index : data) { |
313 | tensor->add_scomplex_val(std::real(index)); |
314 | tensor->add_scomplex_val(std::imag(index)); |
315 | } |
316 | const auto& array_shape = array.shape(); |
317 | auto* shape = tensor->mutable_tensor_shape(); |
318 | for (int i = 0; i < array_shape.dimensions_count(); i++) { |
319 | shape->add_dim()->set_size(array_shape.dims(i)); |
320 | } |
321 | } |
322 | |
323 | void CreateMatrixShapeTensorConst(const std::string& name, int rows, int cols, |
324 | GraphDef* tensorflow_graph) { |
325 | if (HasAlreadyExportedConst(name, *tensorflow_graph)) { |
326 | return; |
327 | } |
328 | tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); |
329 | const_op->set_op("Const" ); |
330 | const_op->set_name(name); |
331 | (*const_op->mutable_attr())["dtype" ].set_type(DT_INT32); |
332 | auto* tensor = (*const_op->mutable_attr())["value" ].mutable_tensor(); |
333 | tensor->set_dtype(DT_INT32); |
334 | const int32 data[2] = {cols, rows}; |
335 | tensor->set_tensor_content( |
336 | std::string(reinterpret_cast<const char*>(data), sizeof(data))); |
337 | auto* shape = tensor->mutable_tensor_shape(); |
338 | shape->add_dim()->set_size(2); |
339 | } |
340 | |
341 | void CreateDummyConcatDimTensorConst(const std::string& name, int dim, |
342 | GraphDef* tensorflow_graph) { |
343 | if (HasAlreadyExportedConst(name, *tensorflow_graph)) { |
344 | return; |
345 | } |
346 | tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); |
347 | const_op->set_op("Const" ); |
348 | const_op->set_name(name); |
349 | (*const_op->mutable_attr())["dtype" ].set_type(DT_INT32); |
350 | auto* tensor = (*const_op->mutable_attr())["value" ].mutable_tensor(); |
351 | tensor->set_dtype(DT_INT32); |
352 | tensor->add_int_val(dim); |
353 | } |
354 | |
355 | void CreateReshapeShapeTensorConst(const std::string& name, |
356 | const std::vector<int32>& shape, |
357 | GraphDef* tensorflow_graph) { |
358 | if (HasAlreadyExportedConst(name, *tensorflow_graph)) { |
359 | return; |
360 | } |
361 | tensorflow::NodeDef* const_op = tensorflow_graph->add_node(); |
362 | const_op->set_op("Const" ); |
363 | const_op->set_name(name); |
364 | (*const_op->mutable_attr())["dtype" ].set_type(DT_INT32); |
365 | auto* tensor = (*const_op->mutable_attr())["value" ].mutable_tensor(); |
366 | tensor->set_dtype(DT_INT32); |
367 | for (auto s : shape) { |
368 | tensor->add_int_val(s); |
369 | } |
370 | // TensorFlow sometimes forbids what it calls "legacy scalars", |
371 | // which are shapes of size 1 where the unique shape size is 1. |
372 | // See OpKernel::IsLegacyScalar and OpKernel::allow_legacy_scalars. |
373 | if (shape.size() > 1) { |
374 | auto* tensor_shape = tensor->mutable_tensor_shape(); |
375 | tensor_shape->add_dim()->set_size(shape.size()); |
376 | } |
377 | } |
378 | |
379 | std::string WalkUpToConstantArray(const Model& model, const std::string& name) { |
380 | const Array& original_array = model.GetArray(name); |
381 | if (original_array.buffer) { |
382 | return name; |
383 | } |
384 | const auto* op = GetOpWithOutput(model, name); |
385 | CHECK(op); |
386 | CHECK(op->type == OperatorType::kFakeQuant); |
387 | const std::string& input_of_fakequant_name = op->inputs[0]; |
388 | const Array& input_of_fakequant = model.GetArray(input_of_fakequant_name); |
389 | CHECK(input_of_fakequant.buffer); |
390 | return input_of_fakequant_name; |
391 | } |
392 | |
// Exports a toco ConvOperator as a TensorFlow "Conv2D" node. When the
// operator has a third (bias) input, the Conv2D output is renamed
// "<output>/conv" and a "BiasAdd" node producing the original output name is
// appended.
//
// Side effects on `tensorflow_graph`, in order: the Conv2D node, a Const for
// the weights (permuted OHWI -> HWIO), then, if biased, the BiasAdd node and
// a 1-D Const for the bias. Both weights and bias are resolved through
// WalkUpToConstantArray, so they may come from a FakeQuant's input. The "T"
// attributes are hard-coded to DT_FLOAT, and the weights/bias buffers are
// CHECKed to be kFloat.
void ConvertConvOperator(const Model& model, const ConvOperator& src_op,
                         GraphDef* tensorflow_graph) {
  const bool has_bias = src_op.inputs.size() >= 3;
  std::string conv_output = src_op.outputs[0];
  if (has_bias) {
    // The final output name is reserved for the BiasAdd node below.
    conv_output += "/conv";
  }

  tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node();
  conv2d_op->set_op("Conv2D");
  conv2d_op->set_name(conv_output);
  *conv2d_op->add_input() = src_op.inputs[0];
  *conv2d_op->add_input() = src_op.inputs[1];
  (*conv2d_op->mutable_attr())["T"].set_type(DT_FLOAT);
  // The weights Const is exported under the walked-up name, which may differ
  // from inputs[1] when that input is a FakeQuant output.
  const std::string& weights_array_name =
      WalkUpToConstantArray(model, src_op.inputs[1]);
  const auto& weights_array = model.GetArray(weights_array_name);
  CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
  ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI,
                          AxesOrder::kHWIO, tensorflow_graph);
  // Strides are emitted in the 4-entry form [1, height, width, 1].
  auto& strides = (*conv2d_op->mutable_attr())["strides"];
  strides.mutable_list()->add_i(1);
  strides.mutable_list()->add_i(src_op.stride_height);
  strides.mutable_list()->add_i(src_op.stride_width);
  strides.mutable_list()->add_i(1);
  // The "dilations" attribute is only emitted when some factor is not 1.
  if ((src_op.dilation_width_factor != 1) ||
      (src_op.dilation_height_factor != 1)) {
    auto& dilations = (*conv2d_op->mutable_attr())["dilations"];
    dilations.mutable_list()->add_i(1);
    dilations.mutable_list()->add_i(src_op.dilation_height_factor);
    dilations.mutable_list()->add_i(src_op.dilation_width_factor);
    dilations.mutable_list()->add_i(1);
  }
  std::string padding;
  if (src_op.padding.type == PaddingType::kSame) {
    padding = "SAME";
  } else if (src_op.padding.type == PaddingType::kValid) {
    padding = "VALID";
  } else {
    LOG(FATAL) << "Bad padding (only SAME and VALID are supported)";
  }
  (*conv2d_op->mutable_attr())["padding"].set_s(padding);

  if (has_bias) {
    tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
    biasadd_op->set_op("BiasAdd");
    biasadd_op->set_name(src_op.outputs[0]);
    biasadd_op->add_input(conv_output);
    biasadd_op->add_input(src_op.inputs[2]);
    (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);
    CHECK(model.HasArray(src_op.inputs[2]));
    const std::string& bias_array_name =
        WalkUpToConstantArray(model, src_op.inputs[2]);
    const auto& bias_array = model.GetArray(bias_array_name);
    // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
    Shape bias_shape_1d = bias_array.shape();
    UnextendShape(&bias_shape_1d, 1);
    CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
    const float* bias_data =
        bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
    // BiasAdd needs a 1-D bias, so a depth-1 bias is kept 1-D even though
    // that makes it a legacy scalar.
    ConvertFloatTensorConst(bias_array_name, bias_shape_1d, bias_data,
                            AxesOrder::kOneAxis, AxesOrder::kOneAxis,
                            tensorflow_graph,
                            LegacyScalarPolicy::kDoCreateLegacyScalars);
  }
}
459 | |
460 | void ConvertDepthwiseConvOperator(const Model& model, |
461 | const DepthwiseConvOperator& src_op, |
462 | GraphDef* tensorflow_graph) { |
463 | const bool has_bias = src_op.inputs.size() >= 3; |
464 | std::string conv_output = src_op.outputs[0]; |
465 | if (has_bias) { |
466 | conv_output += "/conv" ; |
467 | } |
468 | |
469 | tensorflow::NodeDef* dc2d_op = tensorflow_graph->add_node(); |
470 | dc2d_op->set_op("DepthwiseConv2dNative" ); |
471 | dc2d_op->set_name(conv_output); |
472 | *dc2d_op->add_input() = src_op.inputs[0]; |
473 | *dc2d_op->add_input() = src_op.inputs[1]; |
474 | (*dc2d_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
475 | |
476 | // Our internal DepthwiseConv weights are 1 x H x W x OutputDepth. |
477 | // We need to convert that to H x W x InputDepth x Multiplier. |
478 | // That's only a matter of constructing a Dims object; the actual |
479 | // array layout is the same. |
480 | CHECK(model.HasArray(src_op.inputs[1])); |
481 | const std::string& src_weights_name = |
482 | WalkUpToConstantArray(model, src_op.inputs[1]); |
483 | const auto& src_weights_array = model.GetArray(src_weights_name); |
484 | const auto& src_weights_shape = src_weights_array.shape(); |
485 | CHECK_EQ(src_weights_shape.dimensions_count(), 4); |
486 | const Shape dst_weights_shape = |
487 | Shape({src_weights_shape.dims(1), src_weights_shape.dims(2), |
488 | src_weights_shape.dims(3) / src_op.depth_multiplier, |
489 | src_op.depth_multiplier}); |
490 | CHECK_EQ(src_weights_shape.dims(3) % src_op.depth_multiplier, 0); |
491 | CHECK(dst_weights_shape.dims(2) * dst_weights_shape.dims(3) == |
492 | src_weights_shape.dims(3)); |
493 | CHECK_EQ(src_weights_shape.dims(0), 1); |
494 | |
495 | CHECK(src_weights_array.buffer->type == ArrayDataType::kFloat); |
496 | const float* src_weights_data = |
497 | src_weights_array.GetBuffer<ArrayDataType::kFloat>().data.data(); |
498 | ConvertFloatTensorConst(src_weights_name, dst_weights_shape, src_weights_data, |
499 | AxesOrder::kHWIM, AxesOrder::kHWIM, tensorflow_graph); |
500 | |
501 | auto& strides = (*dc2d_op->mutable_attr())["strides" ]; |
502 | strides.mutable_list()->add_i(1); |
503 | strides.mutable_list()->add_i(src_op.stride_height); |
504 | strides.mutable_list()->add_i(src_op.stride_width); |
505 | strides.mutable_list()->add_i(1); |
506 | // TODO(b/116063589): To return a working TF GraphDef, we should be returning |
507 | // the correct SpaceToBatchNd and BatchToSpaceND operation before and after |
508 | // the conv since TF doesn't support dilations. |
509 | if ((src_op.dilation_width_factor != 1) || |
510 | (src_op.dilation_height_factor != 1)) { |
511 | auto& dilations = (*dc2d_op->mutable_attr())["dilations" ]; |
512 | dilations.mutable_list()->add_i(1); |
513 | dilations.mutable_list()->add_i(src_op.dilation_height_factor); |
514 | dilations.mutable_list()->add_i(src_op.dilation_width_factor); |
515 | dilations.mutable_list()->add_i(1); |
516 | } |
517 | std::string padding; |
518 | if (src_op.padding.type == PaddingType::kSame) { |
519 | padding = "SAME" ; |
520 | } else if (src_op.padding.type == PaddingType::kValid) { |
521 | padding = "VALID" ; |
522 | } else { |
523 | LOG(FATAL) << "Bad padding (only SAME and VALID are supported)" ; |
524 | } |
525 | (*dc2d_op->mutable_attr())["padding" ].set_s(padding); |
526 | |
527 | if (has_bias) { |
528 | tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node(); |
529 | biasadd_op->set_op("BiasAdd" ); |
530 | biasadd_op->set_name(src_op.outputs[0]); |
531 | biasadd_op->add_input(conv_output); |
532 | biasadd_op->add_input(src_op.inputs[2]); |
533 | (*biasadd_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
534 | CHECK(model.HasArray(src_op.inputs[2])); |
535 | const std::string& bias_name = |
536 | WalkUpToConstantArray(model, src_op.inputs[2]); |
537 | const auto& bias_array = model.GetArray(bias_name); |
538 | // TODO(b/62904716) Bias arrays should be 1-D, and used directly. |
539 | Shape bias_shape_1d = bias_array.shape(); |
540 | UnextendShape(&bias_shape_1d, 1); |
541 | CHECK(bias_array.buffer->type == ArrayDataType::kFloat); |
542 | const float* bias_data = |
543 | bias_array.GetBuffer<ArrayDataType::kFloat>().data.data(); |
544 | ConvertFloatTensorConst(bias_name, bias_shape_1d, bias_data, |
545 | AxesOrder::kOneAxis, AxesOrder::kOneAxis, |
546 | tensorflow_graph, |
547 | LegacyScalarPolicy::kDoCreateLegacyScalars); |
548 | } |
549 | } |
550 | |
551 | void ConvertTransposeConvOperator(const Model& model, |
552 | const TransposeConvOperator& src_op, |
553 | GraphDef* tensorflow_graph) { |
554 | tensorflow::NodeDef* conv2d_op = tensorflow_graph->add_node(); |
555 | conv2d_op->set_op("Conv2DBackpropInput" ); |
556 | conv2d_op->set_name(src_op.outputs[0]); |
557 | *conv2d_op->add_input() = src_op.inputs[0]; |
558 | *conv2d_op->add_input() = src_op.inputs[1]; |
559 | *conv2d_op->add_input() = src_op.inputs[2]; |
560 | (*conv2d_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
561 | const std::string& weights_array_name = WalkUpToConstantArray( |
562 | model, src_op.inputs[TransposeConvOperator::WEIGHTS]); |
563 | const auto& weights_array = model.GetArray(weights_array_name); |
564 | CHECK(weights_array.buffer->type == ArrayDataType::kFloat); |
565 | ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI, |
566 | AxesOrder::kHWOI, tensorflow_graph); |
567 | auto& strides = (*conv2d_op->mutable_attr())["strides" ]; |
568 | strides.mutable_list()->add_i(1); |
569 | strides.mutable_list()->add_i(src_op.stride_height); |
570 | strides.mutable_list()->add_i(src_op.stride_width); |
571 | strides.mutable_list()->add_i(1); |
572 | std::string padding; |
573 | if (src_op.padding.type == PaddingType::kSame) { |
574 | padding = "SAME" ; |
575 | } else if (src_op.padding.type == PaddingType::kValid) { |
576 | padding = "VALID" ; |
577 | } else { |
578 | LOG(FATAL) << "Bad padding (only SAME and VALID are supported)" ; |
579 | } |
580 | (*conv2d_op->mutable_attr())["padding" ].set_s(padding); |
581 | } |
582 | |
583 | void ConvertDepthToSpaceOperator(const Model& model, |
584 | const DepthToSpaceOperator& src_op, |
585 | GraphDef* tensorflow_graph) { |
586 | tensorflow::NodeDef* op = tensorflow_graph->add_node(); |
587 | op->set_op("DepthToSpace" ); |
588 | op->set_name(src_op.outputs[0]); |
589 | *op->add_input() = src_op.inputs[0]; |
590 | (*op->mutable_attr())["T" ].set_type(DT_FLOAT); |
591 | (*op->mutable_attr())["block_size" ].set_i(src_op.block_size); |
592 | } |
593 | |
594 | void ConvertSpaceToDepthOperator(const Model& model, |
595 | const SpaceToDepthOperator& src_op, |
596 | GraphDef* tensorflow_graph) { |
597 | tensorflow::NodeDef* op = tensorflow_graph->add_node(); |
598 | op->set_op("SpaceToDepth" ); |
599 | op->set_name(src_op.outputs[0]); |
600 | *op->add_input() = src_op.inputs[0]; |
601 | (*op->mutable_attr())["T" ].set_type(DT_FLOAT); |
602 | (*op->mutable_attr())["block_size" ].set_i(src_op.block_size); |
603 | } |
604 | |
605 | void ConvertFullyConnectedOperator(const Model& model, |
606 | const FullyConnectedOperator& src_op, |
607 | GraphDef* tensorflow_graph) { |
608 | // Reshape input activations to have the shape expected by the MatMul. |
609 | const std::string reshape_output = |
610 | AvailableArrayName(model, src_op.outputs[0] + "/reshape" ); |
611 | const std::string reshape_shape = |
612 | AvailableArrayName(model, reshape_output + "/shape" ); |
613 | const auto& fc_weights_array = model.GetArray(src_op.inputs[1]); |
614 | const auto& fc_weights_shape = fc_weights_array.shape(); |
615 | CHECK_EQ(fc_weights_shape.dimensions_count(), 2); |
616 | CreateMatrixShapeTensorConst(reshape_shape, fc_weights_shape.dims(1), -1, |
617 | tensorflow_graph); |
618 | tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); |
619 | reshape_op->set_op("Reshape" ); |
620 | reshape_op->set_name(reshape_output); |
621 | reshape_op->add_input(src_op.inputs[0]); |
622 | reshape_op->add_input(reshape_shape); |
623 | (*reshape_op->mutable_attr())["T" ].set_type( |
624 | GetTensorFlowDataType(model, src_op.inputs[0])); |
625 | |
626 | const bool has_bias = src_op.inputs.size() >= 3; |
627 | std::string matmul_output = src_op.outputs[0]; |
628 | if (has_bias) { |
629 | matmul_output += "/matmul" ; |
630 | } |
631 | |
632 | // Transpose the RHS input from column-major to row-major to match TensorFlow |
633 | // expectations. This is the inverse of the transpose we do during |
634 | // ResolveTensorFlowMatMul. |
635 | const std::string transpose_output = |
636 | AvailableArrayName(model, matmul_output + "/transpose_weights" ); |
637 | const std::string transpose_perm = |
638 | AvailableArrayName(model, transpose_output + "/perm" ); |
639 | CreateIntTensorConst(transpose_perm, {1, 0}, {2}, tensorflow_graph); |
640 | tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node(); |
641 | transpose_op->set_op("Transpose" ); |
642 | transpose_op->set_name(transpose_output); |
643 | *transpose_op->add_input() = src_op.inputs[1]; |
644 | *transpose_op->add_input() = transpose_perm; |
645 | (*transpose_op->mutable_attr())["T" ].set_type( |
646 | GetTensorFlowDataType(model, src_op.inputs[1])); |
647 | (*transpose_op->mutable_attr())["Tperm" ].set_type(DT_INT32); |
648 | |
649 | tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node(); |
650 | matmul_op->set_op("MatMul" ); |
651 | matmul_op->set_name(matmul_output); |
652 | *matmul_op->add_input() = reshape_output; |
653 | *matmul_op->add_input() = transpose_op->name(); |
654 | (*matmul_op->mutable_attr())["T" ].set_type( |
655 | GetTensorFlowDataType(model, src_op.inputs[0])); |
656 | (*matmul_op->mutable_attr())["transpose_a" ].set_b(false); |
657 | (*matmul_op->mutable_attr())["transpose_b" ].set_b(false); |
658 | CHECK(model.HasArray(src_op.inputs[1])); |
659 | |
660 | // Add the bias, if it exists. |
661 | if (has_bias) { |
662 | tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node(); |
663 | biasadd_op->set_op("BiasAdd" ); |
664 | biasadd_op->set_name(src_op.outputs[0]); |
665 | biasadd_op->add_input(matmul_output); |
666 | biasadd_op->add_input(src_op.inputs[2]); |
667 | (*biasadd_op->mutable_attr())["T" ].set_type( |
668 | GetTensorFlowDataType(model, src_op.inputs[0])); |
669 | CHECK(model.HasArray(src_op.inputs[2])); |
670 | const auto& bias_array = model.GetArray(src_op.inputs[2]); |
671 | // TODO(b/62904716) Bias arrays should be 1-D, and used directly. |
672 | Shape bias_shape_1d = bias_array.shape(); |
673 | UnextendShape(&bias_shape_1d, 1); |
674 | CHECK(bias_array.buffer); |
675 | CHECK(bias_array.buffer->type == ArrayDataType::kFloat); |
676 | const float* bias_data = |
677 | bias_array.GetBuffer<ArrayDataType::kFloat>().data.data(); |
678 | ConvertFloatTensorConst(WalkUpToConstantArray(model, src_op.inputs[2]), |
679 | bias_shape_1d, bias_data, AxesOrder::kOneAxis, |
680 | AxesOrder::kOneAxis, tensorflow_graph, |
681 | LegacyScalarPolicy::kDoCreateLegacyScalars); |
682 | } |
683 | } |
684 | |
685 | void ConvertAddOperator(const Model& model, const AddOperator& src_op, |
686 | GraphDef* tensorflow_graph) { |
687 | tensorflow::NodeDef* add_op = tensorflow_graph->add_node(); |
688 | add_op->set_op("Add" ); |
689 | add_op->set_name(src_op.outputs[0]); |
690 | CHECK_EQ(src_op.inputs.size(), 2); |
691 | *add_op->add_input() = src_op.inputs[0]; |
692 | *add_op->add_input() = src_op.inputs[1]; |
693 | (*add_op->mutable_attr())["T" ].set_type( |
694 | GetTensorFlowDataType(model, src_op.outputs[0])); |
695 | } |
696 | |
697 | void ConvertAddNOperator(const Model& model, const AddNOperator& src_op, |
698 | GraphDef* tensorflow_graph) { |
699 | tensorflow::NodeDef* add_op = tensorflow_graph->add_node(); |
700 | add_op->set_op("AddN" ); |
701 | add_op->set_name(src_op.outputs[0]); |
702 | for (const auto& input : src_op.inputs) { |
703 | *add_op->add_input() = input; |
704 | } |
705 | (*add_op->mutable_attr())["N" ].set_i(src_op.inputs.size()); |
706 | (*add_op->mutable_attr())["T" ].set_type( |
707 | GetTensorFlowDataType(model, src_op.outputs[0])); |
708 | } |
709 | |
710 | void ConvertMulOperator(const Model& model, const MulOperator& src_op, |
711 | GraphDef* tensorflow_graph) { |
712 | tensorflow::NodeDef* mul_op = tensorflow_graph->add_node(); |
713 | mul_op->set_op("Mul" ); |
714 | mul_op->set_name(src_op.outputs[0]); |
715 | CHECK_EQ(src_op.inputs.size(), 2); |
716 | *mul_op->add_input() = src_op.inputs[0]; |
717 | *mul_op->add_input() = src_op.inputs[1]; |
718 | (*mul_op->mutable_attr())["T" ].set_type( |
719 | GetTensorFlowDataType(model, src_op.outputs[0])); |
720 | } |
721 | |
722 | void ConvertDivOperator(const Model& model, const DivOperator& src_op, |
723 | GraphDef* tensorflow_graph) { |
724 | tensorflow::NodeDef* div_op = tensorflow_graph->add_node(); |
725 | div_op->set_op("Div" ); |
726 | div_op->set_name(src_op.outputs[0]); |
727 | CHECK_EQ(src_op.inputs.size(), 2); |
728 | *div_op->add_input() = src_op.inputs[0]; |
729 | *div_op->add_input() = src_op.inputs[1]; |
730 | (*div_op->mutable_attr())["T" ].set_type( |
731 | GetTensorFlowDataType(model, src_op.outputs[0])); |
732 | } |
733 | |
734 | void ConvertReluOperator(const Model& model, const ReluOperator& src_op, |
735 | GraphDef* tensorflow_graph) { |
736 | tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); |
737 | relu_op->set_op("Relu" ); |
738 | relu_op->set_name(src_op.outputs[0]); |
739 | *relu_op->add_input() = src_op.inputs[0]; |
740 | (*relu_op->mutable_attr())["T" ].set_type( |
741 | GetTensorFlowDataType(model, src_op.outputs[0])); |
742 | } |
743 | |
744 | void ConvertRelu1Operator(const Relu1Operator& src_op, |
745 | GraphDef* tensorflow_graph) { |
746 | const std::string max_bounds = src_op.outputs[0] + "/max_bounds" ; |
747 | const std::string min_bounds = src_op.outputs[0] + "/min_bounds" ; |
748 | const std::string max_output = src_op.outputs[0] + "/max_output" ; |
749 | |
750 | tensorflow::NodeDef* max_bounds_const_op = tensorflow_graph->add_node(); |
751 | max_bounds_const_op->set_op("Const" ); |
752 | max_bounds_const_op->set_name(max_bounds); |
753 | (*max_bounds_const_op->mutable_attr())["dtype" ].set_type(DT_FLOAT); |
754 | auto* max_bounds_const_op_tensor = |
755 | (*max_bounds_const_op->mutable_attr())["value" ].mutable_tensor(); |
756 | max_bounds_const_op_tensor->set_dtype(DT_FLOAT); |
757 | max_bounds_const_op_tensor->add_float_val(-1.0f); |
758 | |
759 | tensorflow::NodeDef* min_bounds_const_op = tensorflow_graph->add_node(); |
760 | min_bounds_const_op->set_op("Const" ); |
761 | min_bounds_const_op->set_name(min_bounds); |
762 | (*min_bounds_const_op->mutable_attr())["dtype" ].set_type(DT_FLOAT); |
763 | auto* min_bounds_const_op_tensor = |
764 | (*min_bounds_const_op->mutable_attr())["value" ].mutable_tensor(); |
765 | min_bounds_const_op_tensor->set_dtype(DT_FLOAT); |
766 | min_bounds_const_op_tensor->add_float_val(1.0f); |
767 | |
768 | tensorflow::NodeDef* max_op = tensorflow_graph->add_node(); |
769 | max_op->set_op("Maximum" ); |
770 | max_op->set_name(max_output); |
771 | *max_op->add_input() = src_op.inputs[0]; |
772 | *max_op->add_input() = max_bounds; |
773 | (*max_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
774 | |
775 | tensorflow::NodeDef* min_op = tensorflow_graph->add_node(); |
776 | min_op->set_op("Minimum" ); |
777 | min_op->set_name(src_op.outputs[0]); |
778 | *min_op->add_input() = max_output; |
779 | *min_op->add_input() = min_bounds; |
780 | (*min_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
781 | } |
782 | |
783 | void ConvertRelu6Operator(const Relu6Operator& src_op, |
784 | GraphDef* tensorflow_graph) { |
785 | tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); |
786 | relu_op->set_op("Relu6" ); |
787 | relu_op->set_name(src_op.outputs[0]); |
788 | *relu_op->add_input() = src_op.inputs[0]; |
789 | (*relu_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
790 | } |
791 | |
792 | void ConvertLogOperator(const LogOperator& src_op, GraphDef* tensorflow_graph) { |
793 | tensorflow::NodeDef* op = tensorflow_graph->add_node(); |
794 | op->set_op("Log" ); |
795 | op->set_name(src_op.outputs[0]); |
796 | CHECK_EQ(src_op.inputs.size(), 1); |
797 | *op->add_input() = src_op.inputs[0]; |
798 | (*op->mutable_attr())["T" ].set_type(DT_FLOAT); |
799 | } |
800 | |
801 | void ConvertLogisticOperator(const LogisticOperator& src_op, |
802 | GraphDef* tensorflow_graph) { |
803 | tensorflow::NodeDef* relu_op = tensorflow_graph->add_node(); |
804 | relu_op->set_op("Sigmoid" ); |
805 | relu_op->set_name(src_op.outputs[0]); |
806 | *relu_op->add_input() = src_op.inputs[0]; |
807 | (*relu_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
808 | } |
809 | |
810 | void ConvertTanhOperator(const TanhOperator& src_op, |
811 | GraphDef* tensorflow_graph) { |
812 | tensorflow::NodeDef* tanh_op = tensorflow_graph->add_node(); |
813 | tanh_op->set_op("Tanh" ); |
814 | tanh_op->set_name(src_op.outputs[0]); |
815 | *tanh_op->add_input() = src_op.inputs[0]; |
816 | (*tanh_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
817 | } |
818 | |
819 | void ConvertSoftmaxOperator(const Model& model, const SoftmaxOperator& src_op, |
820 | GraphDef* tensorflow_graph) { |
821 | std::string softmax_input; |
822 | Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); |
823 | if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) { |
824 | softmax_input = src_op.inputs[0]; |
825 | } else { |
826 | // Insert a reshape operator that reduces the dimensions down to the 2 that |
827 | // are required for TensorFlow Logits. |
828 | const std::string reshape_output = |
829 | src_op.outputs[0] + "/softmax_insert_reshape" ; |
830 | const std::string softmax_size = src_op.outputs[0] + "/softmax_insert_size" ; |
831 | softmax_input = reshape_output; |
832 | |
833 | tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); |
834 | reshape_op->set_op("Reshape" ); |
835 | reshape_op->set_name(reshape_output); |
836 | *reshape_op->add_input() = src_op.inputs[0]; |
837 | *reshape_op->add_input() = softmax_size; |
838 | (*reshape_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
839 | |
840 | const auto& input_shape = model.GetArray(src_op.inputs[0]).shape(); |
841 | int32_t flattened_size = 1; |
842 | for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) { |
843 | flattened_size *= input_shape.dims(i); |
844 | } |
845 | const std::vector<int32> shape_data = { |
846 | flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)}; |
847 | CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph); |
848 | } |
849 | |
850 | tensorflow::NodeDef* softmax_op = tensorflow_graph->add_node(); |
851 | softmax_op->set_op("Softmax" ); |
852 | softmax_op->set_name(src_op.outputs[0]); |
853 | *softmax_op->add_input() = softmax_input; |
854 | // TensorFlow's Softmax doesn't seem to admit a 'beta' parameter |
855 | CHECK_EQ(src_op.beta, 1.f); |
856 | (*softmax_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
857 | } |
858 | |
859 | void ConvertLogSoftmaxOperator(const Model& model, |
860 | const LogSoftmaxOperator& src_op, |
861 | GraphDef* tensorflow_graph) { |
862 | std::string softmax_input; |
863 | Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); |
864 | if (providing_op != nullptr && providing_op->type == OperatorType::kReshape) { |
865 | softmax_input = src_op.inputs[0]; |
866 | } else { |
867 | // Insert a reshape operator that reduces the dimensions down to the 2 that |
868 | // are required for TensorFlow Logits. |
869 | const std::string reshape_output = |
870 | src_op.outputs[0] + "/log_softmax_insert_reshape" ; |
871 | const std::string softmax_size = |
872 | src_op.outputs[0] + "/log_softmax_insert_size" ; |
873 | softmax_input = reshape_output; |
874 | |
875 | tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); |
876 | reshape_op->set_op("Reshape" ); |
877 | reshape_op->set_name(reshape_output); |
878 | *reshape_op->add_input() = src_op.inputs[0]; |
879 | *reshape_op->add_input() = softmax_size; |
880 | (*reshape_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
881 | |
882 | const auto& input_shape = model.GetArray(src_op.inputs[0]).shape(); |
883 | int32_t flattened_size = 1; |
884 | for (int i = 0; i < input_shape.dimensions_count() - 1; ++i) { |
885 | flattened_size *= input_shape.dims(i); |
886 | } |
887 | const std::vector<int32> shape_data = { |
888 | flattened_size, input_shape.dims(input_shape.dimensions_count() - 1)}; |
889 | CreateReshapeShapeTensorConst(softmax_size, shape_data, tensorflow_graph); |
890 | } |
891 | |
892 | tensorflow::NodeDef* log_softmax_op = tensorflow_graph->add_node(); |
893 | log_softmax_op->set_op("LogSoftmax" ); |
894 | log_softmax_op->set_name(src_op.outputs[0]); |
895 | *log_softmax_op->add_input() = softmax_input; |
896 | (*log_softmax_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
897 | } |
898 | |
899 | void ConvertL2NormalizationOperator(const L2NormalizationOperator& src_op, |
900 | GraphDef* tensorflow_graph) { |
901 | const std::string square_output = src_op.outputs[0] + "/square" ; |
902 | const std::string sum_reduction_indices = |
903 | src_op.outputs[0] + "/reduction_indices" ; |
904 | const std::string sum_output = src_op.outputs[0] + "/sum" ; |
905 | const std::string rsqrt_output = src_op.outputs[0] + "/rsqrt" ; |
906 | const std::string rsqrt_tiled_output = src_op.outputs[0] + "/rsqrt_tiled" ; |
907 | |
908 | tensorflow::NodeDef* sum_reduction_indices_op = tensorflow_graph->add_node(); |
909 | sum_reduction_indices_op->set_op("Const" ); |
910 | sum_reduction_indices_op->set_name(sum_reduction_indices); |
911 | (*sum_reduction_indices_op->mutable_attr())["dtype" ].set_type(DT_INT32); |
912 | auto* sum_reduction_indices_tensor = |
913 | (*sum_reduction_indices_op->mutable_attr())["value" ].mutable_tensor(); |
914 | sum_reduction_indices_tensor->set_dtype(DT_INT32); |
915 | auto* sum_reduction_indices_shape = |
916 | sum_reduction_indices_tensor->mutable_tensor_shape(); |
917 | auto* sum_reduction_indices_dim = sum_reduction_indices_shape->add_dim(); |
918 | sum_reduction_indices_dim->set_size(2); |
919 | sum_reduction_indices_tensor->add_int_val(0); |
920 | sum_reduction_indices_tensor->add_int_val(1); |
921 | |
922 | tensorflow::NodeDef* square_op = tensorflow_graph->add_node(); |
923 | square_op->set_op("Square" ); |
924 | square_op->set_name(square_output); |
925 | *square_op->add_input() = src_op.inputs[0]; |
926 | (*square_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
927 | |
928 | tensorflow::NodeDef* sum_op = tensorflow_graph->add_node(); |
929 | sum_op->set_op("Sum" ); |
930 | sum_op->set_name(sum_output); |
931 | *sum_op->add_input() = square_output; |
932 | *sum_op->add_input() = sum_reduction_indices; |
933 | (*sum_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
934 | |
935 | tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node(); |
936 | rsqrt_op->set_op("Rsqrt" ); |
937 | rsqrt_op->set_name(rsqrt_output); |
938 | *rsqrt_op->add_input() = sum_output; |
939 | (*rsqrt_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
940 | |
941 | tensorflow::NodeDef* mul_op = tensorflow_graph->add_node(); |
942 | mul_op->set_op("Mul" ); |
943 | mul_op->set_name(src_op.outputs[0]); |
944 | *mul_op->add_input() = src_op.inputs[0]; |
945 | *mul_op->add_input() = rsqrt_output; |
946 | (*mul_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
947 | } |
948 | |
949 | void ConvertLocalResponseNormalizationOperator( |
950 | const LocalResponseNormalizationOperator& src_op, |
951 | GraphDef* tensorflow_graph) { |
952 | tensorflow::NodeDef* lrn_op = tensorflow_graph->add_node(); |
953 | lrn_op->set_op("LRN" ); |
954 | lrn_op->set_name(src_op.outputs[0]); |
955 | *lrn_op->add_input() = src_op.inputs[0]; |
956 | (*lrn_op->mutable_attr())["depth_radius" ].set_i(src_op.range); |
957 | (*lrn_op->mutable_attr())["bias" ].set_f(src_op.bias); |
958 | (*lrn_op->mutable_attr())["alpha" ].set_f(src_op.alpha); |
959 | (*lrn_op->mutable_attr())["beta" ].set_f(src_op.beta); |
960 | } |
961 | |
962 | void ConvertFakeQuantOperator(const FakeQuantOperator& src_op, |
963 | GraphDef* tensorflow_graph) { |
964 | tensorflow::NodeDef* fakequant_op = tensorflow_graph->add_node(); |
965 | fakequant_op->set_op("FakeQuantWithMinMaxArgs" ); |
966 | fakequant_op->set_name(src_op.outputs[0]); |
967 | CHECK_EQ(src_op.inputs.size(), 1); |
968 | *fakequant_op->add_input() = src_op.inputs[0]; |
969 | CHECK(src_op.minmax); |
970 | (*fakequant_op->mutable_attr())["min" ].set_f(src_op.minmax->min); |
971 | (*fakequant_op->mutable_attr())["max" ].set_f(src_op.minmax->max); |
972 | if (src_op.num_bits) { |
973 | (*fakequant_op->mutable_attr())["num_bits" ].set_i(src_op.num_bits); |
974 | } |
975 | if (src_op.narrow_range) { |
976 | (*fakequant_op->mutable_attr())["narrow_range" ].set_b(src_op.narrow_range); |
977 | } |
978 | } |
979 | |
980 | void ConvertMaxPoolOperator(const MaxPoolOperator& src_op, |
981 | GraphDef* tensorflow_graph) { |
982 | tensorflow::NodeDef* maxpool_op = tensorflow_graph->add_node(); |
983 | maxpool_op->set_op("MaxPool" ); |
984 | maxpool_op->set_name(src_op.outputs[0]); |
985 | *maxpool_op->add_input() = src_op.inputs[0]; |
986 | auto& strides = (*maxpool_op->mutable_attr())["strides" ]; |
987 | strides.mutable_list()->add_i(1); |
988 | strides.mutable_list()->add_i(src_op.stride_height); |
989 | strides.mutable_list()->add_i(src_op.stride_width); |
990 | strides.mutable_list()->add_i(1); |
991 | std::string padding; |
992 | if (src_op.padding.type == PaddingType::kSame) { |
993 | padding = "SAME" ; |
994 | } else if (src_op.padding.type == PaddingType::kValid) { |
995 | padding = "VALID" ; |
996 | } else { |
997 | LOG(FATAL) << "Bad padding (only SAME and VALID are supported)" ; |
998 | } |
999 | (*maxpool_op->mutable_attr())["padding" ].set_s(padding); |
1000 | (*maxpool_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
1001 | auto& ksize = (*maxpool_op->mutable_attr())["ksize" ]; |
1002 | ksize.mutable_list()->add_i(1); |
1003 | ksize.mutable_list()->add_i(src_op.kheight); |
1004 | ksize.mutable_list()->add_i(src_op.kwidth); |
1005 | ksize.mutable_list()->add_i(1); |
1006 | } |
1007 | |
1008 | void ConvertAveragePoolOperator(const AveragePoolOperator& src_op, |
1009 | GraphDef* tensorflow_graph) { |
1010 | tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node(); |
1011 | avgpool_op->set_op("AvgPool" ); |
1012 | avgpool_op->set_name(src_op.outputs[0]); |
1013 | *avgpool_op->add_input() = src_op.inputs[0]; |
1014 | auto& strides = (*avgpool_op->mutable_attr())["strides" ]; |
1015 | strides.mutable_list()->add_i(1); |
1016 | strides.mutable_list()->add_i(src_op.stride_height); |
1017 | strides.mutable_list()->add_i(src_op.stride_width); |
1018 | strides.mutable_list()->add_i(1); |
1019 | std::string padding; |
1020 | if (src_op.padding.type == PaddingType::kSame) { |
1021 | padding = "SAME" ; |
1022 | } else if (src_op.padding.type == PaddingType::kValid) { |
1023 | padding = "VALID" ; |
1024 | } else { |
1025 | LOG(FATAL) << "Bad padding (only SAME and VALID are supported)" ; |
1026 | } |
1027 | (*avgpool_op->mutable_attr())["padding" ].set_s(padding); |
1028 | (*avgpool_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
1029 | auto& ksize = (*avgpool_op->mutable_attr())["ksize" ]; |
1030 | ksize.mutable_list()->add_i(1); |
1031 | ksize.mutable_list()->add_i(src_op.kheight); |
1032 | ksize.mutable_list()->add_i(src_op.kwidth); |
1033 | ksize.mutable_list()->add_i(1); |
1034 | } |
1035 | |
1036 | void ConvertConcatenationOperator(const Model& model, |
1037 | const ConcatenationOperator& src_op, |
1038 | GraphDef* tensorflow_graph) { |
1039 | tensorflow::NodeDef* dc_op = tensorflow_graph->add_node(); |
1040 | dc_op->set_op("ConcatV2" ); |
1041 | dc_op->set_name(src_op.outputs[0]); |
1042 | const std::string dummy_axis = src_op.outputs[0] + "/axis" ; |
1043 | CreateDummyConcatDimTensorConst(dummy_axis, src_op.axis, tensorflow_graph); |
1044 | for (const auto& input : src_op.inputs) { |
1045 | *dc_op->add_input() = input; |
1046 | } |
1047 | *dc_op->add_input() = dummy_axis; |
1048 | (*dc_op->mutable_attr())["T" ].set_type( |
1049 | GetTensorFlowDataType(model, src_op.inputs[0])); |
1050 | (*dc_op->mutable_attr())["Tidx" ].set_type(DT_INT32); |
1051 | (*dc_op->mutable_attr())["N" ].set_i(src_op.inputs.size()); |
1052 | } |
1053 | |
1054 | void ConvertTensorFlowReshapeOperator(const Model& model, |
1055 | const TensorFlowReshapeOperator& src_op, |
1056 | GraphDef* tensorflow_graph) { |
1057 | tensorflow::NodeDef* reshape_op = tensorflow_graph->add_node(); |
1058 | reshape_op->set_op("Reshape" ); |
1059 | reshape_op->set_name(src_op.outputs[0]); |
1060 | CHECK_EQ(src_op.inputs.size(), 2); |
1061 | *reshape_op->add_input() = src_op.inputs[0]; |
1062 | *reshape_op->add_input() = src_op.inputs[1]; |
1063 | (*reshape_op->mutable_attr())["T" ].set_type( |
1064 | GetTensorFlowDataType(model, src_op.outputs[0])); |
1065 | const auto& shape_array = model.GetArray(src_op.inputs[1]); |
1066 | QCHECK(shape_array.data_type == ArrayDataType::kInt32) |
1067 | << "Only int32 shape is supported." ; |
1068 | QCHECK(shape_array.buffer != nullptr) |
1069 | << "Shape inferred at runtime is not supported." ; |
1070 | const auto& shape_data = shape_array.GetBuffer<ArrayDataType::kInt32>().data; |
1071 | CreateReshapeShapeTensorConst(src_op.inputs[1], shape_data, tensorflow_graph); |
1072 | } |
1073 | |
1074 | void ConvertL2PoolOperator(const L2PoolOperator& src_op, |
1075 | GraphDef* tensorflow_graph) { |
1076 | const std::string square_output = src_op.outputs[0] + "/square" ; |
1077 | const std::string avgpool_output = src_op.outputs[0] + "/avgpool" ; |
1078 | |
1079 | tensorflow::NodeDef* square_op = tensorflow_graph->add_node(); |
1080 | square_op->set_op("Square" ); |
1081 | square_op->set_name(square_output); |
1082 | *square_op->add_input() = src_op.inputs[0]; |
1083 | (*square_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
1084 | |
1085 | std::string padding; |
1086 | if (src_op.padding.type == PaddingType::kSame) { |
1087 | padding = "SAME" ; |
1088 | } else if (src_op.padding.type == PaddingType::kValid) { |
1089 | padding = "VALID" ; |
1090 | } else { |
1091 | LOG(FATAL) << "Bad padding (only SAME and VALID are supported)" ; |
1092 | } |
1093 | |
1094 | tensorflow::NodeDef* avgpool_op = tensorflow_graph->add_node(); |
1095 | avgpool_op->set_op("AvgPool" ); |
1096 | avgpool_op->set_name(avgpool_output); |
1097 | *avgpool_op->add_input() = square_output; |
1098 | auto& strides = (*avgpool_op->mutable_attr())["strides" ]; |
1099 | strides.mutable_list()->add_i(1); |
1100 | strides.mutable_list()->add_i(src_op.stride_height); |
1101 | strides.mutable_list()->add_i(src_op.stride_width); |
1102 | strides.mutable_list()->add_i(1); |
1103 | |
1104 | (*avgpool_op->mutable_attr())["padding" ].set_s(padding); |
1105 | (*avgpool_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
1106 | auto& ksize = (*avgpool_op->mutable_attr())["ksize" ]; |
1107 | ksize.mutable_list()->add_i(1); |
1108 | ksize.mutable_list()->add_i(src_op.kheight); |
1109 | ksize.mutable_list()->add_i(src_op.kwidth); |
1110 | ksize.mutable_list()->add_i(1); |
1111 | |
1112 | tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node(); |
1113 | sqrt_op->set_op("Sqrt" ); |
1114 | sqrt_op->set_name(src_op.outputs[0]); |
1115 | *sqrt_op->add_input() = avgpool_output; |
1116 | (*sqrt_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
1117 | } |
1118 | |
1119 | void ConvertSquareOperator(const TensorFlowSquareOperator& src_op, |
1120 | GraphDef* tensorflow_graph) { |
1121 | tensorflow::NodeDef* square_op = tensorflow_graph->add_node(); |
1122 | square_op->set_op("Square" ); |
1123 | square_op->set_name(src_op.outputs[0]); |
1124 | CHECK_EQ(src_op.inputs.size(), 1); |
1125 | *square_op->add_input() = src_op.inputs[0]; |
1126 | (*square_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
1127 | } |
1128 | |
1129 | void ConvertSqrtOperator(const TensorFlowSqrtOperator& src_op, |
1130 | GraphDef* tensorflow_graph) { |
1131 | tensorflow::NodeDef* sqrt_op = tensorflow_graph->add_node(); |
1132 | sqrt_op->set_op("Sqrt" ); |
1133 | sqrt_op->set_name(src_op.outputs[0]); |
1134 | CHECK_EQ(src_op.inputs.size(), 1); |
1135 | *sqrt_op->add_input() = src_op.inputs[0]; |
1136 | (*sqrt_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
1137 | } |
1138 | |
1139 | void ConvertRsqrtOperator(const Model& model, |
1140 | const TensorFlowRsqrtOperator& src_op, |
1141 | GraphDef* tensorflow_graph) { |
1142 | tensorflow::NodeDef* rsqrt_op = tensorflow_graph->add_node(); |
1143 | rsqrt_op->set_op("Rsqrt" ); |
1144 | rsqrt_op->set_name(src_op.outputs[0]); |
1145 | CHECK_EQ(src_op.inputs.size(), 1); |
1146 | *rsqrt_op->add_input() = src_op.inputs[0]; |
1147 | const tensorflow::DataType data_type = |
1148 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1149 | (*rsqrt_op->mutable_attr())["T" ].set_type(data_type); |
1150 | } |
1151 | |
1152 | void ConvertSplitOperator(const Model& model, |
1153 | const TensorFlowSplitOperator& src_op, |
1154 | GraphDef* tensorflow_graph) { |
1155 | tensorflow::NodeDef* split_op = tensorflow_graph->add_node(); |
1156 | split_op->set_op("Split" ); |
1157 | split_op->set_name(src_op.outputs[0]); |
1158 | for (const auto& input : src_op.inputs) { |
1159 | *split_op->add_input() = input; |
1160 | } |
1161 | (*split_op->mutable_attr())["T" ].set_type( |
1162 | GetTensorFlowDataType(model, src_op.outputs[0])); |
1163 | (*split_op->mutable_attr())["num_split" ].set_i(src_op.num_split); |
1164 | const auto& split_dim_array = model.GetArray(src_op.inputs[0]); |
1165 | CHECK(split_dim_array.buffer); |
1166 | CHECK(split_dim_array.data_type == ArrayDataType::kInt32); |
1167 | const auto& split_dim_data = |
1168 | split_dim_array.GetBuffer<ArrayDataType::kInt32>().data; |
1169 | CHECK_EQ(split_dim_data.size(), 1); |
1170 | const int split_dim = split_dim_data[0]; |
1171 | CreateDummyConcatDimTensorConst(src_op.inputs[0], split_dim, |
1172 | tensorflow_graph); |
1173 | } |
1174 | |
1175 | void ConvertSplitVOperator(const Model& model, |
1176 | const TensorFlowSplitVOperator& src_op, |
1177 | GraphDef* tensorflow_graph) { |
1178 | tensorflow::NodeDef* split_v_op = tensorflow_graph->add_node(); |
1179 | split_v_op->set_op("SplitV" ); |
1180 | split_v_op->set_name(src_op.outputs[0]); |
1181 | for (const auto& input : src_op.inputs) { |
1182 | *split_v_op->add_input() = input; |
1183 | } |
1184 | (*split_v_op->mutable_attr())["T" ].set_type( |
1185 | GetTensorFlowDataType(model, src_op.outputs[0])); |
1186 | (*split_v_op->mutable_attr())["Tlen" ].set_type( |
1187 | GetTensorFlowDataType(model, src_op.inputs[1])); |
1188 | (*split_v_op->mutable_attr())["num_split" ].set_i(src_op.num_split); |
1189 | ConvertIntTensorConst(model, src_op.inputs[1], tensorflow_graph); |
1190 | } |
1191 | |
1192 | void ConvertCastOperator(const Model& model, const CastOperator& src_op, |
1193 | GraphDef* tensorflow_graph) { |
1194 | tensorflow::NodeDef* cast_op = tensorflow_graph->add_node(); |
1195 | cast_op->set_op("Cast" ); |
1196 | cast_op->set_name(src_op.outputs[0]); |
1197 | CHECK_EQ(src_op.inputs.size(), 1); |
1198 | *cast_op->add_input() = src_op.inputs[0]; |
1199 | |
1200 | (*cast_op->mutable_attr())["DstT" ].set_type( |
1201 | GetTensorFlowDataType(model, src_op.outputs[0])); |
1202 | (*cast_op->mutable_attr())["SrcT" ].set_type( |
1203 | GetTensorFlowDataType(model, src_op.inputs[0])); |
1204 | } |
1205 | |
1206 | void ConvertFloorOperator(const Model& model, const FloorOperator& src_op, |
1207 | GraphDef* tensorflow_graph) { |
1208 | tensorflow::NodeDef* floor_op = tensorflow_graph->add_node(); |
1209 | floor_op->set_op("Floor" ); |
1210 | floor_op->set_name(src_op.outputs[0]); |
1211 | CHECK_EQ(src_op.inputs.size(), 1); |
1212 | *floor_op->add_input() = src_op.inputs[0]; |
1213 | (*floor_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
1214 | } |
1215 | |
1216 | void ConvertCeilOperator(const Model& model, const CeilOperator& src_op, |
1217 | GraphDef* tensorflow_graph) { |
1218 | tensorflow::NodeDef* ceil_op = tensorflow_graph->add_node(); |
1219 | ceil_op->set_op("Ceil" ); |
1220 | ceil_op->set_name(src_op.outputs[0]); |
1221 | CHECK_EQ(src_op.inputs.size(), 1); |
1222 | *ceil_op->add_input() = src_op.inputs[0]; |
1223 | (*ceil_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
1224 | } |
1225 | |
1226 | void ConvertRoundOperator(const Model& model, const RoundOperator& src_op, |
1227 | GraphDef* tensorflow_graph) { |
1228 | tensorflow::NodeDef* round_op = tensorflow_graph->add_node(); |
1229 | round_op->set_op("Round" ); |
1230 | round_op->set_name(src_op.outputs[0]); |
1231 | CHECK_EQ(src_op.inputs.size(), 1); |
1232 | *round_op->add_input() = src_op.inputs[0]; |
1233 | (*round_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
1234 | } |
1235 | |
1236 | void ConvertGatherOperator(const Model& model, const GatherOperator& src_op, |
1237 | GraphDef* tensorflow_graph) { |
1238 | tensorflow::NodeDef* gather_op = tensorflow_graph->add_node(); |
1239 | gather_op->set_op("GatherV2" ); |
1240 | gather_op->set_name(src_op.outputs[0]); |
1241 | *gather_op->add_input() = src_op.inputs[0]; |
1242 | *gather_op->add_input() = src_op.inputs[1]; |
1243 | |
1244 | if (!src_op.axis) { |
1245 | // Dynamic axis. |
1246 | CHECK_EQ(src_op.inputs.size(), 3); |
1247 | *gather_op->add_input() = src_op.inputs[2]; |
1248 | } else { |
1249 | // Constant axis. |
1250 | CHECK_EQ(src_op.inputs.size(), 2); |
1251 | const std::string gather_axis = |
1252 | AvailableArrayName(model, gather_op->name() + "/axis" ); |
1253 | CreateIntTensorConst(gather_axis, {src_op.axis.value()}, {}, |
1254 | tensorflow_graph); |
1255 | *gather_op->add_input() = gather_axis; |
1256 | } |
1257 | |
1258 | (*gather_op->mutable_attr())["Tindices" ].set_type(DT_INT32); |
1259 | (*gather_op->mutable_attr())["Taxis" ].set_type(DT_INT32); |
1260 | const tensorflow::DataType params_type = |
1261 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1262 | (*gather_op->mutable_attr())["Tparams" ].set_type(params_type); |
1263 | } |
1264 | |
1265 | void ConvertArgMaxOperator(const Model& model, const ArgMaxOperator& src_op, |
1266 | GraphDef* tensorflow_graph) { |
1267 | tensorflow::NodeDef* argmax_op = tensorflow_graph->add_node(); |
1268 | argmax_op->set_op("ArgMax" ); |
1269 | argmax_op->set_name(src_op.outputs[0]); |
1270 | CHECK_EQ(src_op.inputs.size(), 2); |
1271 | *argmax_op->add_input() = src_op.inputs[0]; |
1272 | *argmax_op->add_input() = src_op.inputs[1]; |
1273 | (*argmax_op->mutable_attr())["T" ].set_type( |
1274 | GetTensorFlowDataType(model, src_op.inputs[0])); |
1275 | (*argmax_op->mutable_attr())["Tidx" ].set_type( |
1276 | GetTensorFlowDataType(model, src_op.inputs[1])); |
1277 | (*argmax_op->mutable_attr())["output_type" ].set_type( |
1278 | GetTensorFlowDataType(model, src_op.outputs[0])); |
1279 | } |
1280 | |
1281 | void ConvertArgMinOperator(const Model& model, const ArgMinOperator& src_op, |
1282 | GraphDef* tensorflow_graph) { |
1283 | tensorflow::NodeDef* argmin_op = tensorflow_graph->add_node(); |
1284 | argmin_op->set_op("ArgMin" ); |
1285 | argmin_op->set_name(src_op.outputs[0]); |
1286 | CHECK_EQ(src_op.inputs.size(), 2); |
1287 | *argmin_op->add_input() = src_op.inputs[0]; |
1288 | *argmin_op->add_input() = src_op.inputs[1]; |
1289 | (*argmin_op->mutable_attr())["T" ].set_type( |
1290 | GetTensorFlowDataType(model, src_op.inputs[0])); |
1291 | (*argmin_op->mutable_attr())["Tidx" ].set_type( |
1292 | GetTensorFlowDataType(model, src_op.inputs[1])); |
1293 | (*argmin_op->mutable_attr())["output_type" ].set_type( |
1294 | GetTensorFlowDataType(model, src_op.outputs[0])); |
1295 | } |
1296 | |
1297 | void ConvertTransposeOperator(const Model& model, |
1298 | const TransposeOperator& src_op, |
1299 | GraphDef* tensorflow_graph) { |
1300 | tensorflow::NodeDef* transpose_op = tensorflow_graph->add_node(); |
1301 | transpose_op->set_op("Transpose" ); |
1302 | transpose_op->set_name(src_op.outputs[0]); |
1303 | CHECK_EQ(src_op.inputs.size(), 2); |
1304 | *transpose_op->add_input() = src_op.inputs[0]; |
1305 | *transpose_op->add_input() = src_op.inputs[1]; |
1306 | (*transpose_op->mutable_attr())["T" ].set_type( |
1307 | GetTensorFlowDataType(model, src_op.inputs[0])); |
1308 | (*transpose_op->mutable_attr())["Tperm" ].set_type( |
1309 | GetTensorFlowDataType(model, src_op.inputs[1])); |
1310 | } |
1311 | |
1312 | void ConvertTensorFlowShapeOperator(const Model& model, |
1313 | const TensorFlowShapeOperator& src_op, |
1314 | GraphDef* tensorflow_graph) { |
1315 | tensorflow::NodeDef* shape_op = tensorflow_graph->add_node(); |
1316 | shape_op->set_op("Shape" ); |
1317 | shape_op->set_name(src_op.outputs[0]); |
1318 | CHECK_EQ(src_op.inputs.size(), 1); |
1319 | *shape_op->add_input() = src_op.inputs[0]; |
1320 | (*shape_op->mutable_attr())["T" ].set_type( |
1321 | GetTensorFlowDataType(model, src_op.inputs[0])); |
1322 | (*shape_op->mutable_attr())["out_type" ].set_type( |
1323 | GetTensorFlowDataType(model, src_op.outputs[0])); |
1324 | } |
1325 | |
1326 | void ConvertRankOperator(const Model& model, |
1327 | const TensorFlowRankOperator& src_op, |
1328 | GraphDef* tensorflow_graph) { |
1329 | tensorflow::NodeDef* rank_op = tensorflow_graph->add_node(); |
1330 | rank_op->set_op("Rank" ); |
1331 | rank_op->set_name(src_op.outputs[0]); |
1332 | CHECK_EQ(src_op.inputs.size(), 1); |
1333 | *rank_op->add_input() = src_op.inputs[0]; |
1334 | (*rank_op->mutable_attr())["T" ].set_type( |
1335 | GetTensorFlowDataType(model, src_op.inputs[0])); |
1336 | } |
1337 | |
1338 | void ConvertRangeOperator(const Model& model, const RangeOperator& src_op, |
1339 | GraphDef* tensorflow_graph) { |
1340 | tensorflow::NodeDef* range_op = tensorflow_graph->add_node(); |
1341 | range_op->set_op("Range" ); |
1342 | range_op->set_name(src_op.outputs[0]); |
1343 | CHECK_EQ(src_op.inputs.size(), 3); |
1344 | *range_op->add_input() = src_op.inputs[0]; |
1345 | *range_op->add_input() = src_op.inputs[1]; |
1346 | *range_op->add_input() = src_op.inputs[2]; |
1347 | (*range_op->mutable_attr())["Tidx" ].set_type( |
1348 | GetTensorFlowDataTypeForOp(src_op.dtype, /*op_name=*/src_op.outputs[0])); |
1349 | } |
1350 | |
1351 | void ConvertPackOperator(const Model& model, const PackOperator& src_op, |
1352 | GraphDef* tensorflow_graph) { |
1353 | tensorflow::NodeDef* pack_op = tensorflow_graph->add_node(); |
1354 | pack_op->set_op("Pack" ); |
1355 | pack_op->set_name(src_op.outputs[0]); |
1356 | for (const auto& input : src_op.inputs) { |
1357 | *pack_op->add_input() = input; |
1358 | } |
1359 | (*pack_op->mutable_attr())["axis" ].set_i(src_op.axis); |
1360 | (*pack_op->mutable_attr())["N" ].set_i(src_op.inputs.size()); |
1361 | (*pack_op->mutable_attr())["T" ].set_type( |
1362 | GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0])); |
1363 | } |
1364 | |
1365 | void ConvertFillOperator(const Model& model, const FillOperator& src_op, |
1366 | GraphDef* tensorflow_graph) { |
1367 | tensorflow::NodeDef* fill_op = tensorflow_graph->add_node(); |
1368 | fill_op->set_op("Fill" ); |
1369 | fill_op->set_name(src_op.outputs[0]); |
1370 | CHECK_EQ(src_op.inputs.size(), 2); |
1371 | *fill_op->add_input() = src_op.inputs[0]; |
1372 | *fill_op->add_input() = src_op.inputs[1]; |
1373 | (*fill_op->mutable_attr())["index_type" ].set_type( |
1374 | GetTensorFlowDataType(model, src_op.inputs[0])); |
1375 | (*fill_op->mutable_attr())["T" ].set_type( |
1376 | GetTensorFlowDataType(model, src_op.inputs[1])); |
1377 | } |
1378 | |
1379 | void ConvertFloorDivOperator(const Model& model, const FloorDivOperator& src_op, |
1380 | GraphDef* tensorflow_graph) { |
1381 | tensorflow::NodeDef* floor_div_op = tensorflow_graph->add_node(); |
1382 | floor_div_op->set_op("FloorDiv" ); |
1383 | floor_div_op->set_name(src_op.outputs[0]); |
1384 | CHECK_EQ(src_op.inputs.size(), 2); |
1385 | *floor_div_op->add_input() = src_op.inputs[0]; |
1386 | *floor_div_op->add_input() = src_op.inputs[1]; |
1387 | (*floor_div_op->mutable_attr())["T" ].set_type( |
1388 | GetTensorFlowDataType(model, src_op.inputs[0])); |
1389 | } |
1390 | |
1391 | void ConvertFloorModOperator(const Model& model, const FloorModOperator& src_op, |
1392 | GraphDef* tensorflow_graph) { |
1393 | tensorflow::NodeDef* floor_mod_op = tensorflow_graph->add_node(); |
1394 | floor_mod_op->set_op("FloorMod" ); |
1395 | floor_mod_op->set_name(src_op.outputs[0]); |
1396 | DCHECK_EQ(src_op.inputs.size(), 2); |
1397 | *floor_mod_op->add_input() = src_op.inputs[0]; |
1398 | *floor_mod_op->add_input() = src_op.inputs[1]; |
1399 | (*floor_mod_op->mutable_attr())["T" ].set_type( |
1400 | GetTensorFlowDataType(model, src_op.inputs[0])); |
1401 | } |
1402 | |
1403 | void ConvertExpandDimsOperator(const Model& model, |
1404 | const ExpandDimsOperator& src_op, |
1405 | GraphDef* tensorflow_graph) { |
1406 | tensorflow::NodeDef* expand_dims_op = tensorflow_graph->add_node(); |
1407 | expand_dims_op->set_op("ExpandDims" ); |
1408 | expand_dims_op->set_name(src_op.outputs[0]); |
1409 | CHECK_EQ(src_op.inputs.size(), 2); |
1410 | *expand_dims_op->add_input() = src_op.inputs[0]; |
1411 | *expand_dims_op->add_input() = src_op.inputs[1]; |
1412 | (*expand_dims_op->mutable_attr())["T" ].set_type( |
1413 | GetTensorFlowDataType(model, src_op.inputs[0])); |
1414 | (*expand_dims_op->mutable_attr())["Tdim" ].set_type( |
1415 | GetTensorFlowDataType(model, src_op.inputs[1])); |
1416 | } |
1417 | |
1418 | void ConvertResizeBilinearOperator(const Model& model, |
1419 | const ResizeBilinearOperator& src_op, |
1420 | GraphDef* tensorflow_graph) { |
1421 | tensorflow::NodeDef* resize_op = tensorflow_graph->add_node(); |
1422 | resize_op->set_op("ResizeBilinear" ); |
1423 | resize_op->set_name(src_op.outputs[0]); |
1424 | CHECK_EQ(src_op.inputs.size(), 2); |
1425 | *resize_op->add_input() = src_op.inputs[0]; |
1426 | *resize_op->add_input() = src_op.inputs[1]; |
1427 | (*resize_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
1428 | (*resize_op->mutable_attr())["align_corners" ].set_b(src_op.align_corners); |
1429 | (*resize_op->mutable_attr())["half_pixel_centers" ].set_b( |
1430 | src_op.half_pixel_centers); |
1431 | } |
1432 | |
1433 | void ConvertResizeNearestNeighborOperator( |
1434 | const Model& model, const ResizeNearestNeighborOperator& src_op, |
1435 | GraphDef* tensorflow_graph) { |
1436 | tensorflow::NodeDef* resize_op = tensorflow_graph->add_node(); |
1437 | resize_op->set_op("ResizeNearestNeighbor" ); |
1438 | resize_op->set_name(src_op.outputs[0]); |
1439 | CHECK_EQ(src_op.inputs.size(), 2); |
1440 | *resize_op->add_input() = src_op.inputs[0]; |
1441 | *resize_op->add_input() = src_op.inputs[1]; |
1442 | (*resize_op->mutable_attr())["T" ].set_type(DT_FLOAT); |
1443 | (*resize_op->mutable_attr())["align_corners" ].set_b(src_op.align_corners); |
1444 | (*resize_op->mutable_attr())["half_pixel_centers" ].set_b( |
1445 | src_op.half_pixel_centers); |
1446 | } |
1447 | |
1448 | void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op, |
1449 | GraphDef* tensorflow_graph) { |
1450 | tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node(); |
1451 | onehot_op->set_op("OneHot" ); |
1452 | onehot_op->set_name(src_op.outputs[0]); |
1453 | CHECK_EQ(src_op.inputs.size(), 4); |
1454 | for (const auto& input : src_op.inputs) { |
1455 | *onehot_op->add_input() = input; |
1456 | } |
1457 | (*onehot_op->mutable_attr())["T" ].set_type( |
1458 | GetTensorFlowDataType(model, src_op.outputs[0])); |
1459 | (*onehot_op->mutable_attr())["axis" ].set_i(src_op.axis); |
1460 | } |
1461 | |
1462 | namespace { |
1463 | // TODO(aselle): Remove when available in absl |
1464 | absl::string_view FindLongestCommonPrefix(absl::string_view a, |
1465 | absl::string_view b) { |
1466 | if (a.empty() || b.empty()) return absl::string_view(); |
1467 | |
1468 | const char* pa = a.data(); |
1469 | const char* pb = b.data(); |
1470 | std::string::difference_type count = 0; |
1471 | const std::string::difference_type limit = std::min(a.size(), b.size()); |
1472 | while (count < limit && *pa == *pb) { |
1473 | ++pa; |
1474 | ++pb; |
1475 | ++count; |
1476 | } |
1477 | |
1478 | return absl::string_view(a.data(), count); |
1479 | } |
1480 | } // namespace |
1481 | |
// Expands a fused LstmCellOperator into a subgraph of primitive TensorFlow
// ops: ConcatV2 -> MatMul -> BiasAdd -> Split(4) -> Sigmoid/Tanh/Mul/Add.
//
// With s0..s3 denoting the four Split outputs, the emitted nodes compute
//   STATE_OUTPUT = PREV_STATE_INPUT * Sigmoid(s2) + Sigmoid(s0) * Tanh(s1)
//   ACTIV_OUTPUT = Tanh(STATE_OUTPUT) * Sigmoid(s3)
// (presumably the i/j/f/o gate order of tf.slim's basic LSTM cell; confirm).
//
// Intermediate node names are `base` plus a tf.slim-like suffix, where `base`
// is the longest common prefix of the two output array names. The two final
// nodes are named exactly after STATE_OUTPUT and ACTIV_OUTPUT. The weights and
// biases inputs are resolved with WalkUpToConstantArray and CHECKed to hold
// float buffers. They are then re-emitted through ConvertFloatTensorConst.
// Every op node created directly here sets "T" to DT_FLOAT.
void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op,
                             GraphDef* tensorflow_graph) {
  // Find the base name
  const std::string base(
      FindLongestCommonPrefix(src_op.outputs[LstmCellOperator::STATE_OUTPUT],
                              src_op.outputs[LstmCellOperator::ACTIV_OUTPUT]));

  // Concatenate inputs
  const std::string concat_output = base + "basic_lstm_cell/concat";
  // Op names have been chosen to match the tf.slim LSTM naming
  // as closely as possible.
  // The concat axis (reused for the Split below) is the last axis of
  // PREV_ACTIV_INPUT.
  const int axis =
      model.GetArray(src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT])
          .shape()
          .dimensions_count() -
      1;
  // Note that DATA_INPUT may have extra size 1 dimensions, but TF concat
  // works the same since the tensor has the same underlying data layout.
  const std::string axis_output = concat_output + "/axis";
  CreateDummyConcatDimTensorConst(axis_output, axis, tensorflow_graph);
  // ConcatV2(DATA_INPUT, PREV_ACTIV_INPUT, axis).
  tensorflow::NodeDef* concat_op = tensorflow_graph->add_node();
  concat_op->set_op("ConcatV2");
  concat_op->set_name(concat_output);
  *concat_op->add_input() = src_op.inputs[LstmCellOperator::DATA_INPUT];
  *concat_op->add_input() = src_op.inputs[LstmCellOperator::PREV_ACTIV_INPUT];
  *concat_op->add_input() = axis_output;
  (*concat_op->mutable_attr())["T"].set_type(DT_FLOAT);
  (*concat_op->mutable_attr())["Tidx"].set_type(DT_INT32);
  (*concat_op->mutable_attr())["N"].set_i(2);  // Number of inputs

  // Write weights
  const std::string weights_output = base + "weights";
  CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]));
  const std::string weights_name = WalkUpToConstantArray(
      model, src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]);
  const auto& weights_array = model.GetArray(weights_name);
  // The weights must already be a 2D float matrix (CHECKed below). They are
  // emitted with AxesOrder::kCR and AxesOrder::kRC passed to
  // ConvertFloatTensorConst (presumably source/destination orders, i.e. a
  // transpose; confirm against ConvertFloatTensorConst).
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 2);
  CHECK(weights_array.buffer);
  CHECK(weights_array.buffer->type == ArrayDataType::kFloat);
  const float* weights_data =
      weights_array.GetBuffer<ArrayDataType::kFloat>().data.data();
  ConvertFloatTensorConst(weights_output, weights_shape, weights_data,
                          AxesOrder::kCR, AxesOrder::kRC, tensorflow_graph);

  // Fully connected matrix multiply
  // MatMul(concat, weights) with neither operand transposed.
  const std::string matmul_output = base + "MatMul";
  tensorflow::NodeDef* matmul_op = tensorflow_graph->add_node();
  matmul_op->set_op("MatMul");
  matmul_op->set_name(matmul_output);
  *matmul_op->add_input() = concat_output;
  *matmul_op->add_input() = weights_output;
  (*matmul_op->mutable_attr())["transpose_a"].set_b(false);
  (*matmul_op->mutable_attr())["transpose_b"].set_b(false);
  (*matmul_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Write biases
  const std::string biases_output = base + "biases";
  CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]));
  const std::string bias_name = WalkUpToConstantArray(
      model, src_op.inputs[LstmCellOperator::BIASES_INPUT]);
  const auto& bias_array = model.GetArray(bias_name);
  // TODO(b/62904716) Bias arrays should be 1-D, and used directly.
  // Collapse the stored bias shape down to a single dimension before export.
  Shape bias_shape_1d = bias_array.shape();
  UnextendShape(&bias_shape_1d, 1);
  CHECK(bias_array.buffer);
  CHECK(bias_array.buffer->type == ArrayDataType::kFloat);
  const float* bias_data =
      bias_array.GetBuffer<ArrayDataType::kFloat>().data.data();
  ConvertFloatTensorConst(biases_output, bias_shape_1d, bias_data,
                          AxesOrder::kOneAxis, AxesOrder::kOneAxis,
                          tensorflow_graph,
                          LegacyScalarPolicy::kDoCreateLegacyScalars);

  // Add biases
  std::string biasadd_output = base + "BiasAdd";
  tensorflow::NodeDef* biasadd_op = tensorflow_graph->add_node();
  biasadd_op->set_op("BiasAdd");
  biasadd_op->set_name(biasadd_output);
  biasadd_op->add_input(matmul_output);
  biasadd_op->add_input(biases_output);
  (*biasadd_op->mutable_attr())["data_format"].set_s("NHWC");
  (*biasadd_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Split
  // TF's Split takes (split_dim, value); the BiasAdd result is cut into four
  // equal parts referenced below as split (output 0), split:1, split:2, split:3.
  std::string split_dim_output = base + "split/split_dim";
  // The dimension is the same as the concatenation dimension
  CreateDummyConcatDimTensorConst(split_dim_output, axis, tensorflow_graph);
  std::string split_output = base + "split";
  tensorflow::NodeDef* split_op = tensorflow_graph->add_node();
  split_op->set_op("Split");
  split_op->set_name(split_output);
  *split_op->add_input() = split_dim_output;
  *split_op->add_input() = biasadd_output;
  (*split_op->mutable_attr())["T"].set_type(DT_FLOAT);
  (*split_op->mutable_attr())["num_split"].set_i(4);  // Split into four outputs

  // Activation functions and memory computations
  // Tanh(split:1)
  const std::string tanh_0_output = base + "Tanh";
  tensorflow::NodeDef* tanh_0_op = tensorflow_graph->add_node();
  tanh_0_op->set_op("Tanh");
  tanh_0_op->set_name(tanh_0_output);
  *tanh_0_op->add_input() = split_output + ":1";
  (*tanh_0_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Sigmoid(split:0); a bare node name refers to its first output.
  const std::string sigmoid_1_output = base + "Sigmoid_1";
  tensorflow::NodeDef* logistic_1_op = tensorflow_graph->add_node();
  logistic_1_op->set_op("Sigmoid");
  logistic_1_op->set_name(sigmoid_1_output);
  *logistic_1_op->add_input() = split_output;
  (*logistic_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // mul_1 = Sigmoid(split:0) * Tanh(split:1)
  const std::string mul_1_output = base + "mul_1";
  tensorflow::NodeDef* mul_1_op = tensorflow_graph->add_node();
  mul_1_op->set_op("Mul");
  mul_1_op->set_name(mul_1_output);
  *mul_1_op->add_input() = sigmoid_1_output;
  *mul_1_op->add_input() = tanh_0_output;
  (*mul_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Sigmoid(split:2)
  const std::string sigmoid_0_output = base + "Sigmoid";
  tensorflow::NodeDef* logistic_2_op = tensorflow_graph->add_node();
  logistic_2_op->set_op("Sigmoid");
  logistic_2_op->set_name(sigmoid_0_output);
  *logistic_2_op->add_input() = split_output + ":2";
  (*logistic_2_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Sigmoid(split:3)
  const std::string sigmoid_2_output = base + "Sigmoid_2";
  tensorflow::NodeDef* logistic_3_op = tensorflow_graph->add_node();
  logistic_3_op->set_op("Sigmoid");
  logistic_3_op->set_name(sigmoid_2_output);
  *logistic_3_op->add_input() = split_output + ":3";
  (*logistic_3_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // mul = PREV_STATE_INPUT * Sigmoid(split:2)
  const std::string mul_0_output = base + "mul";
  tensorflow::NodeDef* mul_0_op = tensorflow_graph->add_node();
  mul_0_op->set_op("Mul");
  mul_0_op->set_name(mul_0_output);
  *mul_0_op->add_input() = src_op.inputs[LstmCellOperator::PREV_STATE_INPUT];
  *mul_0_op->add_input() = sigmoid_0_output;
  (*mul_0_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // New state = mul + mul_1, emitted under the STATE_OUTPUT array's name.
  const std::string add_1_output =
      src_op.outputs[LstmCellOperator::STATE_OUTPUT];
  tensorflow::NodeDef* add_1_op = tensorflow_graph->add_node();
  add_1_op->set_op("Add");
  add_1_op->set_name(add_1_output);
  *add_1_op->add_input() = mul_0_output;
  *add_1_op->add_input() = mul_1_output;
  (*add_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // Tanh(new state)
  const std::string tanh_1_output = base + "Tanh_1";
  tensorflow::NodeDef* tanh_1_op = tensorflow_graph->add_node();
  tanh_1_op->set_op("Tanh");
  tanh_1_op->set_name(tanh_1_output);
  *tanh_1_op->add_input() = add_1_output;
  (*tanh_1_op->mutable_attr())["T"].set_type(DT_FLOAT);

  // New activation = Tanh(new state) * Sigmoid(split:3), emitted under the
  // ACTIV_OUTPUT array's name.
  const std::string mul_2_output =
      src_op.outputs[LstmCellOperator::ACTIV_OUTPUT];
  tensorflow::NodeDef* mul_2_op = tensorflow_graph->add_node();
  mul_2_op->set_op("Mul");
  mul_2_op->set_name(mul_2_output);
  *mul_2_op->add_input() = tanh_1_output;
  *mul_2_op->add_input() = sigmoid_2_output;
  (*mul_2_op->mutable_attr())["T"].set_type(DT_FLOAT);
}
1650 | |
1651 | void ConvertSpaceToBatchNDOperator(const Model& model, |
1652 | const SpaceToBatchNDOperator& src_op, |
1653 | GraphDef* tensorflow_graph) { |
1654 | tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); |
1655 | new_op->set_op("SpaceToBatchND" ); |
1656 | new_op->set_name(src_op.outputs[0]); |
1657 | CHECK_EQ(src_op.inputs.size(), 3); |
1658 | *new_op->add_input() = src_op.inputs[0]; |
1659 | *new_op->add_input() = src_op.inputs[1]; |
1660 | *new_op->add_input() = src_op.inputs[2]; |
1661 | const tensorflow::DataType params_type = |
1662 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1663 | (*new_op->mutable_attr())["T" ].set_type(params_type); |
1664 | (*new_op->mutable_attr())["Tblock_shape" ].set_type(DT_INT32); |
1665 | (*new_op->mutable_attr())["Tpaddings" ].set_type(DT_INT32); |
1666 | } |
1667 | |
1668 | void ConvertBatchToSpaceNDOperator(const Model& model, |
1669 | const BatchToSpaceNDOperator& src_op, |
1670 | GraphDef* tensorflow_graph) { |
1671 | tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); |
1672 | new_op->set_op("BatchToSpaceND" ); |
1673 | new_op->set_name(src_op.outputs[0]); |
1674 | CHECK_EQ(src_op.inputs.size(), 3); |
1675 | *new_op->add_input() = src_op.inputs[0]; |
1676 | *new_op->add_input() = src_op.inputs[1]; |
1677 | *new_op->add_input() = src_op.inputs[2]; |
1678 | const tensorflow::DataType params_type = |
1679 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1680 | (*new_op->mutable_attr())["T" ].set_type(params_type); |
1681 | (*new_op->mutable_attr())["Tblock_shape" ].set_type(DT_INT32); |
1682 | (*new_op->mutable_attr())["Tcrops" ].set_type(DT_INT32); |
1683 | } |
1684 | |
1685 | void ConvertPadOperator(const Model& model, const PadOperator& src_op, |
1686 | GraphDef* tensorflow_graph) { |
1687 | tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); |
1688 | new_op->set_op("Pad" ); |
1689 | new_op->set_name(src_op.outputs[0]); |
1690 | CHECK_EQ(src_op.inputs.size(), 2); |
1691 | *new_op->add_input() = src_op.inputs[0]; |
1692 | *new_op->add_input() = src_op.inputs[1]; |
1693 | |
1694 | const tensorflow::DataType params_type = |
1695 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1696 | (*new_op->mutable_attr())["T" ].set_type(params_type); |
1697 | |
1698 | // Create the params tensor. |
1699 | tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); |
1700 | params_op->set_op("Const" ); |
1701 | params_op->set_name(src_op.inputs[1]); |
1702 | (*params_op->mutable_attr())["dtype" ].set_type(DT_INT32); |
1703 | auto* tensor = (*params_op->mutable_attr())["value" ].mutable_tensor(); |
1704 | tensor->set_dtype(DT_INT32); |
1705 | |
1706 | CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size()); |
1707 | for (int i = 0; i < src_op.left_padding.size(); ++i) { |
1708 | tensor->add_int_val(src_op.left_padding[i]); |
1709 | tensor->add_int_val(src_op.right_padding[i]); |
1710 | } |
1711 | auto* shape = tensor->mutable_tensor_shape(); |
1712 | shape->add_dim()->set_size(src_op.left_padding.size()); |
1713 | shape->add_dim()->set_size(2); |
1714 | } |
1715 | |
1716 | void ConvertPadV2Operator(const Model& model, const PadV2Operator& src_op, |
1717 | GraphDef* tensorflow_graph) { |
1718 | tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); |
1719 | new_op->set_op("PadV2" ); |
1720 | new_op->set_name(src_op.outputs[0]); |
1721 | CHECK_EQ(src_op.inputs.size(), 2); |
1722 | *new_op->add_input() = src_op.inputs[0]; |
1723 | *new_op->add_input() = src_op.inputs[1]; |
1724 | *new_op->add_input() = src_op.inputs[2]; |
1725 | |
1726 | const tensorflow::DataType params_type = |
1727 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1728 | (*new_op->mutable_attr())["T" ].set_type(params_type); |
1729 | |
1730 | // Create the params tensor. |
1731 | tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); |
1732 | params_op->set_op("Const" ); |
1733 | params_op->set_name(src_op.inputs[1]); |
1734 | (*params_op->mutable_attr())["dtype" ].set_type(DT_INT32); |
1735 | auto* tensor = (*params_op->mutable_attr())["value" ].mutable_tensor(); |
1736 | tensor->set_dtype(DT_INT32); |
1737 | |
1738 | CHECK_EQ(src_op.left_padding.size(), src_op.right_padding.size()); |
1739 | for (int i = 0; i < src_op.left_padding.size(); ++i) { |
1740 | tensor->add_int_val(src_op.left_padding[i]); |
1741 | tensor->add_int_val(src_op.right_padding[i]); |
1742 | } |
1743 | auto* shape = tensor->mutable_tensor_shape(); |
1744 | shape->add_dim()->set_size(src_op.left_padding.size()); |
1745 | shape->add_dim()->set_size(2); |
1746 | } |
1747 | |
1748 | void CreateSliceInput(const std::string& input_name, |
1749 | const std::vector<int>& values, |
1750 | GraphDef* tensorflow_graph) { |
1751 | tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); |
1752 | params_op->set_op("Const" ); |
1753 | params_op->set_name(input_name); |
1754 | (*params_op->mutable_attr())["dtype" ].set_type(DT_INT32); |
1755 | auto* tensor = (*params_op->mutable_attr())["value" ].mutable_tensor(); |
1756 | tensor->set_dtype(DT_INT32); |
1757 | |
1758 | for (int i = 0; i < values.size(); ++i) { |
1759 | tensor->add_int_val(values[i]); |
1760 | } |
1761 | auto* shape = tensor->mutable_tensor_shape(); |
1762 | shape->add_dim()->set_size(values.size()); |
1763 | } |
1764 | |
1765 | void ConvertStridedSliceOperator(const Model& model, |
1766 | const StridedSliceOperator& src_op, |
1767 | GraphDef* tensorflow_graph) { |
1768 | tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); |
1769 | new_op->set_op("StridedSlice" ); |
1770 | new_op->set_name(src_op.outputs[0]); |
1771 | CHECK_EQ(src_op.inputs.size(), 4); |
1772 | *new_op->add_input() = src_op.inputs[0]; |
1773 | *new_op->add_input() = src_op.inputs[1]; |
1774 | *new_op->add_input() = src_op.inputs[2]; |
1775 | *new_op->add_input() = src_op.inputs[3]; |
1776 | |
1777 | const tensorflow::DataType params_type = |
1778 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1779 | (*new_op->mutable_attr())["T" ].set_type(params_type); |
1780 | |
1781 | (*new_op->mutable_attr())["Index" ].set_type(DT_INT32); |
1782 | (*new_op->mutable_attr())["begin_mask" ].set_i(src_op.begin_mask); |
1783 | (*new_op->mutable_attr())["ellipsis_mask" ].set_i(src_op.ellipsis_mask); |
1784 | (*new_op->mutable_attr())["end_mask" ].set_i(src_op.end_mask); |
1785 | (*new_op->mutable_attr())["new_axis_mask" ].set_i(src_op.new_axis_mask); |
1786 | (*new_op->mutable_attr())["shrink_axis_mask" ].set_i(src_op.shrink_axis_mask); |
1787 | |
1788 | // Create tensors for start/stop indices and strides. |
1789 | CreateSliceInput(src_op.inputs[1], src_op.start_indices, tensorflow_graph); |
1790 | CreateSliceInput(src_op.inputs[2], src_op.stop_indices, tensorflow_graph); |
1791 | CreateSliceInput(src_op.inputs[3], src_op.strides, tensorflow_graph); |
1792 | } |
1793 | |
1794 | void ConvertSliceOperator(const Model& model, const SliceOperator& src_op, |
1795 | GraphDef* tensorflow_graph) { |
1796 | tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); |
1797 | new_op->set_op("Slice" ); |
1798 | new_op->set_name(src_op.outputs[0]); |
1799 | CHECK_EQ(src_op.inputs.size(), 3); |
1800 | *new_op->add_input() = src_op.inputs[0]; |
1801 | *new_op->add_input() = src_op.inputs[1]; |
1802 | *new_op->add_input() = src_op.inputs[2]; |
1803 | |
1804 | const tensorflow::DataType params_type = |
1805 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1806 | (*new_op->mutable_attr())["T" ].set_type(params_type); |
1807 | (*new_op->mutable_attr())["Index" ].set_type(DT_INT32); |
1808 | |
1809 | // Create tensors for begin and size inputs. |
1810 | CreateSliceInput(src_op.inputs[1], src_op.begin, tensorflow_graph); |
1811 | CreateSliceInput(src_op.inputs[2], src_op.size, tensorflow_graph); |
1812 | } |
1813 | |
1814 | template <typename T> |
1815 | void ConvertReduceOperator(const Model& model, const T& src_op, |
1816 | GraphDef* tensorflow_graph, |
1817 | const std::string& op_name) { |
1818 | tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); |
1819 | new_op->set_op(op_name); |
1820 | new_op->set_name(src_op.outputs[0]); |
1821 | CHECK_EQ(src_op.inputs.size(), 2); |
1822 | *new_op->add_input() = src_op.inputs[0]; |
1823 | *new_op->add_input() = src_op.inputs[1]; |
1824 | |
1825 | if (src_op.type != OperatorType::kAny) { |
1826 | const tensorflow::DataType params_type = |
1827 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1828 | (*new_op->mutable_attr())["T" ].set_type(params_type); |
1829 | } |
1830 | const tensorflow::DataType indices_type = |
1831 | GetTensorFlowDataType(model, src_op.inputs[1]); |
1832 | (*new_op->mutable_attr())["Tidx" ].set_type(indices_type); |
1833 | |
1834 | if (src_op.keep_dims) { |
1835 | (*new_op->mutable_attr())["keep_dims" ].set_b(true); |
1836 | } |
1837 | |
1838 | // Create the params tensor. |
1839 | tensorflow::NodeDef* params_op = tensorflow_graph->add_node(); |
1840 | params_op->set_op("Const" ); |
1841 | params_op->set_name(src_op.inputs[1]); |
1842 | (*params_op->mutable_attr())["dtype" ].set_type(DT_INT32); |
1843 | auto* tensor = (*params_op->mutable_attr())["value" ].mutable_tensor(); |
1844 | tensor->set_dtype(DT_INT32); |
1845 | |
1846 | for (int i = 0; i < src_op.axis.size(); ++i) { |
1847 | tensor->add_int_val(src_op.axis[i]); |
1848 | } |
1849 | auto* shape = tensor->mutable_tensor_shape(); |
1850 | shape->add_dim()->set_size(src_op.axis.size()); |
1851 | } |
1852 | |
1853 | void ConvertSqueezeOperator(const Model& model, const SqueezeOperator& src_op, |
1854 | GraphDef* tensorflow_graph) { |
1855 | tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); |
1856 | new_op->set_op("Squeeze" ); |
1857 | new_op->set_name(src_op.outputs[0]); |
1858 | CHECK_EQ(src_op.inputs.size(), 1); |
1859 | *new_op->add_input() = src_op.inputs[0]; |
1860 | |
1861 | const tensorflow::DataType params_type = |
1862 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1863 | (*new_op->mutable_attr())["T" ].set_type(params_type); |
1864 | |
1865 | if (!src_op.squeeze_dims.empty()) { |
1866 | auto& squeeze_dims = (*new_op->mutable_attr())["squeeze_dims" ]; |
1867 | for (int i : src_op.squeeze_dims) { |
1868 | squeeze_dims.mutable_list()->add_i(i); |
1869 | } |
1870 | } |
1871 | } |
1872 | |
1873 | void ConvertSubOperator(const Model& model, const SubOperator& src_op, |
1874 | GraphDef* tensorflow_graph) { |
1875 | tensorflow::NodeDef* sub_op = tensorflow_graph->add_node(); |
1876 | sub_op->set_op("Sub" ); |
1877 | sub_op->set_name(src_op.outputs[0]); |
1878 | CHECK_EQ(src_op.inputs.size(), 2); |
1879 | *sub_op->add_input() = src_op.inputs[0]; |
1880 | *sub_op->add_input() = src_op.inputs[1]; |
1881 | const tensorflow::DataType data_type = |
1882 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1883 | (*sub_op->mutable_attr())["T" ].set_type(data_type); |
1884 | } |
1885 | |
1886 | void ConvertTensorFlowMinimumOperator(const Model& model, |
1887 | const TensorFlowMinimumOperator& src_op, |
1888 | GraphDef* tensorflow_graph) { |
1889 | tensorflow::NodeDef* min_op = tensorflow_graph->add_node(); |
1890 | min_op->set_op("Minimum" ); |
1891 | min_op->set_name(src_op.outputs[0]); |
1892 | CHECK_EQ(src_op.inputs.size(), 2); |
1893 | *min_op->add_input() = src_op.inputs[0]; |
1894 | *min_op->add_input() = src_op.inputs[1]; |
1895 | const tensorflow::DataType data_type = |
1896 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1897 | (*min_op->mutable_attr())["T" ].set_type(data_type); |
1898 | } |
1899 | |
1900 | void ConvertTensorFlowMaximumOperator(const Model& model, |
1901 | const TensorFlowMaximumOperator& src_op, |
1902 | GraphDef* tensorflow_graph) { |
1903 | tensorflow::NodeDef* max_op = tensorflow_graph->add_node(); |
1904 | max_op->set_op("Maximum" ); |
1905 | max_op->set_name(src_op.outputs[0]); |
1906 | CHECK_EQ(src_op.inputs.size(), 2); |
1907 | *max_op->add_input() = src_op.inputs[0]; |
1908 | *max_op->add_input() = src_op.inputs[1]; |
1909 | const tensorflow::DataType data_type = |
1910 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1911 | (*max_op->mutable_attr())["T" ].set_type(data_type); |
1912 | } |
1913 | |
1914 | void ConvertSelectOperator(const Model& model, const SelectOperator& src_op, |
1915 | GraphDef* tensorflow_graph) { |
1916 | tensorflow::NodeDef* select_op = tensorflow_graph->add_node(); |
1917 | select_op->set_op("Select" ); |
1918 | select_op->set_name(src_op.outputs[0]); |
1919 | CHECK_EQ(src_op.inputs.size(), 3); |
1920 | *select_op->add_input() = src_op.inputs[0]; |
1921 | *select_op->add_input() = src_op.inputs[1]; |
1922 | *select_op->add_input() = src_op.inputs[2]; |
1923 | const tensorflow::DataType data_type = |
1924 | GetTensorFlowDataType(model, src_op.inputs[1]); |
1925 | (*select_op->mutable_attr())["T" ].set_type(data_type); |
1926 | } |
1927 | |
1928 | void ConvertTileOperator(const Model& model, |
1929 | const TensorFlowTileOperator& src_op, |
1930 | GraphDef* tensorflow_graph) { |
1931 | tensorflow::NodeDef* tile_op = tensorflow_graph->add_node(); |
1932 | tile_op->set_op("Tile" ); |
1933 | tile_op->set_name(src_op.outputs[0]); |
1934 | CHECK_EQ(src_op.inputs.size(), 2); |
1935 | *tile_op->add_input() = src_op.inputs[0]; |
1936 | *tile_op->add_input() = src_op.inputs[1]; |
1937 | const tensorflow::DataType data_type = |
1938 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1939 | (*tile_op->mutable_attr())["T" ].set_type(data_type); |
1940 | const tensorflow::DataType multiples_data_type = |
1941 | GetTensorFlowDataType(model, src_op.inputs[1]); |
1942 | (*tile_op->mutable_attr())["Tmultiples" ].set_type(multiples_data_type); |
1943 | } |
1944 | |
1945 | void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op, |
1946 | GraphDef* tensorflow_graph) { |
1947 | tensorflow::NodeDef* topk_op = tensorflow_graph->add_node(); |
1948 | topk_op->set_op("TopKV2" ); |
1949 | topk_op->set_name(src_op.outputs[0]); |
1950 | CHECK_EQ(src_op.inputs.size(), 2); |
1951 | *topk_op->add_input() = src_op.inputs[0]; |
1952 | *topk_op->add_input() = src_op.inputs[1]; |
1953 | const tensorflow::DataType data_type = |
1954 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1955 | (*topk_op->mutable_attr())["T" ].set_type(data_type); |
1956 | (*topk_op->mutable_attr())["sorted" ].set_b(true); |
1957 | } |
1958 | |
1959 | void ConvertRandomUniformOperator(const Model& model, |
1960 | const RandomUniformOperator& src_op, |
1961 | GraphDef* tensorflow_graph) { |
1962 | CHECK(tensorflow_graph != nullptr); |
1963 | tensorflow::NodeDef* new_op = tensorflow_graph->add_node(); |
1964 | new_op->set_op("RandomUniform" ); |
1965 | CHECK_EQ(src_op.inputs.size(), 1); |
1966 | new_op->set_name(src_op.outputs[0]); |
1967 | *new_op->add_input() = src_op.inputs[0]; |
1968 | const tensorflow::DataType shape_type = |
1969 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1970 | (*new_op->mutable_attr())["T" ].set_type(shape_type); |
1971 | (*new_op->mutable_attr())["dtype" ].set_type( |
1972 | GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0])); |
1973 | (*new_op->mutable_attr())["seed" ].set_i(src_op.seed); |
1974 | (*new_op->mutable_attr())["seed2" ].set_i(src_op.seed2); |
1975 | } |
1976 | |
1977 | void ConvertComparisonOperator(const Model& model, const Operator& src_op, |
1978 | const char* op_name, |
1979 | GraphDef* tensorflow_graph) { |
1980 | tensorflow::NodeDef* comparison_op = tensorflow_graph->add_node(); |
1981 | comparison_op->set_op(op_name); |
1982 | comparison_op->set_name(src_op.outputs[0]); |
1983 | CHECK_EQ(src_op.inputs.size(), 2); |
1984 | *comparison_op->add_input() = src_op.inputs[0]; |
1985 | *comparison_op->add_input() = src_op.inputs[1]; |
1986 | const tensorflow::DataType data_type = |
1987 | GetTensorFlowDataType(model, src_op.inputs[0]); |
1988 | (*comparison_op->mutable_attr())["T" ].set_type(data_type); |
1989 | } |
1990 | |
1991 | void ConvertSparseToDenseOperator(const Model& model, |
1992 | const SparseToDenseOperator& src_op, |
1993 | const char* op_name, |
1994 | GraphDef* tensorflow_graph) { |
1995 | tensorflow::NodeDef* sparse_to_dense_op = tensorflow_graph->add_node(); |
1996 | sparse_to_dense_op->set_op(op_name); |
1997 | sparse_to_dense_op->set_name(src_op.outputs[0]); |
1998 | CHECK_EQ(src_op.inputs.size(), 4); |
1999 | for (int i = 0; i < 4; ++i) { |
2000 | *sparse_to_dense_op->add_input() = src_op.inputs[i]; |
2001 | } |
2002 | const tensorflow::DataType data_type = |
2003 | GetTensorFlowDataType(model, src_op.inputs[3]); |
2004 | (*sparse_to_dense_op->mutable_attr())["T" ].set_type(data_type); |
2005 | const tensorflow::DataType index_type = |
2006 | GetTensorFlowDataType(model, src_op.inputs[0]); |
2007 | (*sparse_to_dense_op->mutable_attr())["Tindices" ].set_type(index_type); |
2008 | (*sparse_to_dense_op->mutable_attr())["Tindices" ].set_b( |
2009 | src_op.validate_indices); |
2010 | } |
2011 | |
2012 | void ConvertPowOperator(const Model& model, const PowOperator& src_op, |
2013 | const char* op_name, GraphDef* tensorflow_graph) { |
2014 | tensorflow::NodeDef* pow_op = tensorflow_graph->add_node(); |
2015 | pow_op->set_op(op_name); |
2016 | pow_op->set_name(src_op.outputs[0]); |
2017 | CHECK_EQ(src_op.inputs.size(), 2); |
2018 | for (int i = 0; i < 2; ++i) { |
2019 | *pow_op->add_input() = src_op.inputs[i]; |
2020 | } |
2021 | const tensorflow::DataType data_type = |
2022 | GetTensorFlowDataType(model, src_op.inputs[0]); |
2023 | (*pow_op->mutable_attr())["T" ].set_type(data_type); |
2024 | } |
2025 | |
2026 | void ConvertLogicalAndOperator(const Model& model, |
2027 | const LogicalAndOperator& src_op, |
2028 | GraphDef* tensorflow_graph) { |
2029 | tensorflow::NodeDef* logical_op = tensorflow_graph->add_node(); |
2030 | logical_op->set_op("LogicalAnd" ); |
2031 | logical_op->set_name(src_op.outputs[0]); |
2032 | CHECK_EQ(src_op.inputs.size(), 2); |
2033 | for (int i = 0; i < 2; ++i) { |
2034 | *logical_op->add_input() = src_op.inputs[i]; |
2035 | } |
2036 | } |
2037 | |
2038 | void ConvertLogicalNotOperator(const Model& model, |
2039 | const LogicalNotOperator& src_op, |
2040 | GraphDef* tensorflow_graph) { |
2041 | tensorflow::NodeDef* logical_op = tensorflow_graph->add_node(); |
2042 | logical_op->set_op("LogicalNot" ); |
2043 | logical_op->set_name(src_op.outputs[0]); |
2044 | CHECK_EQ(src_op.inputs.size(), 1); |
2045 | *logical_op->add_input() = src_op.inputs[0]; |
2046 | } |
2047 | |
2048 | void ConvertLogicalOrOperator(const Model& model, |
2049 | const LogicalOrOperator& src_op, |
2050 | const char* op_name, GraphDef* tensorflow_graph) { |
2051 | tensorflow::NodeDef* logical_or_op = tensorflow_graph->add_node(); |
2052 | logical_or_op->set_op(op_name); |
2053 | logical_or_op->set_name(src_op.outputs[0]); |
2054 | CHECK_EQ(src_op.inputs.size(), 2); |
2055 | for (int i = 0; i < 2; ++i) { |
2056 | *logical_or_op->add_input() = src_op.inputs[i]; |
2057 | } |
2058 | const tensorflow::DataType data_type = |
2059 | GetTensorFlowDataType(model, src_op.inputs[0]); |
2060 | (*logical_or_op->mutable_attr())["T" ].set_type(data_type); |
2061 | } |
2062 | |
2063 | void ConvertCTCBeamSearchDecoderOperator( |
2064 | const Model& model, const CTCBeamSearchDecoderOperator& src_op, |
2065 | const char* op_name, GraphDef* tensorflow_graph) { |
2066 | auto* op = tensorflow_graph->add_node(); |
2067 | op->set_op(op_name); |
2068 | op->set_name(src_op.outputs[0]); |
2069 | CHECK_EQ(src_op.inputs.size(), 2); |
2070 | for (int i = 0; i < 2; ++i) { |
2071 | *op->add_input() = src_op.inputs[i]; |
2072 | } |
2073 | (*op->mutable_attr())["beam_width" ].set_i(src_op.beam_width); |
2074 | (*op->mutable_attr())["top_paths" ].set_i(src_op.top_paths); |
2075 | (*op->mutable_attr())["merge_repeated" ].set_b(src_op.merge_repeated); |
2076 | } |
2077 | |
2078 | void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op, |
2079 | const char* op_name, GraphDef* tensorflow_graph) { |
2080 | tensorflow::NodeDef* unpack_op = tensorflow_graph->add_node(); |
2081 | unpack_op->set_op(op_name); |
2082 | unpack_op->set_name(src_op.outputs[0]); |
2083 | CHECK_EQ(src_op.inputs.size(), 2); |
2084 | *unpack_op->add_input() = src_op.inputs[0]; |
2085 | const tensorflow::DataType data_type = |
2086 | GetTensorFlowDataType(model, src_op.inputs[0]); |
2087 | (*unpack_op->mutable_attr())["T" ].set_type(data_type); |
2088 | (*unpack_op->mutable_attr())["num" ].set_i(src_op.num); |
2089 | (*unpack_op->mutable_attr())["axis" ].set_i(src_op.axis); |
2090 | } |
2091 | |
2092 | void ConvertZerosLikeOperator(const Model& model, |
2093 | const TensorFlowZerosLikeOperator& src_op, |
2094 | const char* op_name, GraphDef* tensorflow_graph) { |
2095 | tensorflow::NodeDef* zeros_like_op = tensorflow_graph->add_node(); |
2096 | zeros_like_op->set_op(op_name); |
2097 | zeros_like_op->set_name(src_op.outputs[0]); |
2098 | DCHECK_EQ(src_op.inputs.size(), 1); |
2099 | *zeros_like_op->add_input() = src_op.inputs[0]; |
2100 | const tensorflow::DataType data_type = |
2101 | GetTensorFlowDataType(model, src_op.inputs[0]); |
2102 | (*zeros_like_op->mutable_attr())["T" ].set_type(data_type); |
2103 | } |
2104 | |
2105 | void ConvertReverseV2Operator(const Model& model, |
2106 | const ReverseV2Operator& src_op, |
2107 | const char* op_name, GraphDef* tensorflow_graph) { |
2108 | tensorflow::NodeDef* reverse_v2_op = tensorflow_graph->add_node(); |
2109 | reverse_v2_op->set_op(op_name); |
2110 | reverse_v2_op->set_name(src_op.outputs[0]); |
2111 | DCHECK_EQ(src_op.inputs.size(), 2); |
2112 | *reverse_v2_op->add_input() = src_op.inputs[0]; |
2113 | *reverse_v2_op->add_input() = src_op.inputs[1]; |
2114 | const tensorflow::DataType data_type = |
2115 | GetTensorFlowDataType(model, src_op.inputs[0]); |
2116 | (*reverse_v2_op->mutable_attr())["T" ].set_type(data_type); |
2117 | } |
2118 | |
2119 | void ConvertReverseSequenceOperator(const Model& model, |
2120 | const ReverseSequenceOperator& src_op, |
2121 | GraphDef* tensorflow_graph) { |
2122 | tensorflow::NodeDef* reverse_seq_op = tensorflow_graph->add_node(); |
2123 | reverse_seq_op->set_op("ReverseSequence" ); |
2124 | reverse_seq_op->set_name(src_op.outputs[0]); |
2125 | CHECK_EQ(src_op.inputs.size(), 2); |
2126 | *reverse_seq_op->add_input() = src_op.inputs[0]; |
2127 | *reverse_seq_op->add_input() = src_op.inputs[1]; |
2128 | (*reverse_seq_op->mutable_attr())["seq_dim" ].set_i(src_op.seq_dim); |
2129 | (*reverse_seq_op->mutable_attr())["batch_dim" ].set_i(src_op.batch_dim); |
2130 | } |
2131 | |
2132 | void ConvertOperator(const Model& model, const Operator& src_op, |
2133 | GraphDef* tensorflow_graph) { |
2134 | if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { |
2135 | LOG(FATAL) |
2136 | << "Unsupported: the input model has a fused activation function" ; |
2137 | } |
2138 | |
2139 | if (src_op.type == OperatorType::kConv) { |
2140 | ConvertConvOperator(model, static_cast<const ConvOperator&>(src_op), |
2141 | tensorflow_graph); |
2142 | } else if (src_op.type == OperatorType::kDepthwiseConv) { |
2143 | ConvertDepthwiseConvOperator( |
2144 | model, static_cast<const DepthwiseConvOperator&>(src_op), |
2145 | tensorflow_graph); |
2146 | } else if (src_op.type == OperatorType::kDepthToSpace) { |
2147 | ConvertDepthToSpaceOperator( |
2148 | model, static_cast<const DepthToSpaceOperator&>(src_op), |
2149 | tensorflow_graph); |
2150 | } else if (src_op.type == OperatorType::kSpaceToDepth) { |
2151 | ConvertSpaceToDepthOperator( |
2152 | model, static_cast<const SpaceToDepthOperator&>(src_op), |
2153 | tensorflow_graph); |
2154 | } else if (src_op.type == OperatorType::kFullyConnected) { |
2155 | ConvertFullyConnectedOperator( |
2156 | model, static_cast<const FullyConnectedOperator&>(src_op), |
2157 | tensorflow_graph); |
2158 | } else if (src_op.type == OperatorType::kAdd) { |
2159 | ConvertAddOperator(model, static_cast<const AddOperator&>(src_op), |
2160 | tensorflow_graph); |
2161 | } else if (src_op.type == OperatorType::kAddN) { |
2162 | ConvertAddNOperator(model, static_cast<const AddNOperator&>(src_op), |
2163 | tensorflow_graph); |
2164 | } else if (src_op.type == OperatorType::kMul) { |
2165 | ConvertMulOperator(model, static_cast<const MulOperator&>(src_op), |
2166 | tensorflow_graph); |
2167 | } else if (src_op.type == OperatorType::kDiv) { |
2168 | ConvertDivOperator(model, static_cast<const DivOperator&>(src_op), |
2169 | tensorflow_graph); |
2170 | } else if (src_op.type == OperatorType::kRelu) { |
2171 | ConvertReluOperator(model, static_cast<const ReluOperator&>(src_op), |
2172 | tensorflow_graph); |
2173 | } else if (src_op.type == OperatorType::kRelu1) { |
2174 | ConvertRelu1Operator(static_cast<const Relu1Operator&>(src_op), |
2175 | tensorflow_graph); |
2176 | } else if (src_op.type == OperatorType::kRelu6) { |
2177 | ConvertRelu6Operator(static_cast<const Relu6Operator&>(src_op), |
2178 | tensorflow_graph); |
2179 | } else if (src_op.type == OperatorType::kLog) { |
2180 | ConvertLogOperator(static_cast<const LogOperator&>(src_op), |
2181 | tensorflow_graph); |
2182 | } else if (src_op.type == OperatorType::kLogistic) { |
2183 | ConvertLogisticOperator(static_cast<const LogisticOperator&>(src_op), |
2184 | tensorflow_graph); |
2185 | } else if (src_op.type == OperatorType::kTanh) { |
2186 | ConvertTanhOperator(static_cast<const TanhOperator&>(src_op), |
2187 | tensorflow_graph); |
2188 | } else if (src_op.type == OperatorType::kL2Normalization) { |
2189 | ConvertL2NormalizationOperator( |
2190 | static_cast<const L2NormalizationOperator&>(src_op), tensorflow_graph); |
2191 | } else if (src_op.type == OperatorType::kSoftmax) { |
2192 | ConvertSoftmaxOperator(model, static_cast<const SoftmaxOperator&>(src_op), |
2193 | tensorflow_graph); |
2194 | } else if (src_op.type == OperatorType::kLogSoftmax) { |
2195 | ConvertLogSoftmaxOperator(model, |
2196 | static_cast<const LogSoftmaxOperator&>(src_op), |
2197 | tensorflow_graph); |
2198 | } else if (src_op.type == OperatorType::kLocalResponseNormalization) { |
2199 | ConvertLocalResponseNormalizationOperator( |
2200 | static_cast<const LocalResponseNormalizationOperator&>(src_op), |
2201 | tensorflow_graph); |
2202 | } else if (src_op.type == OperatorType::kLstmCell) { |
2203 | ConvertLstmCellOperator(model, static_cast<const LstmCellOperator&>(src_op), |
2204 | tensorflow_graph); |
2205 | } else if (src_op.type == OperatorType::kMaxPool) { |
2206 | ConvertMaxPoolOperator(static_cast<const MaxPoolOperator&>(src_op), |
2207 | tensorflow_graph); |
2208 | } else if (src_op.type == OperatorType::kAveragePool) { |
2209 | ConvertAveragePoolOperator(static_cast<const AveragePoolOperator&>(src_op), |
2210 | tensorflow_graph); |
2211 | } else if (src_op.type == OperatorType::kConcatenation) { |
2212 | ConvertConcatenationOperator( |
2213 | model, static_cast<const ConcatenationOperator&>(src_op), |
2214 | tensorflow_graph); |
2215 | } else if (src_op.type == OperatorType::kReshape) { |
2216 | ConvertTensorFlowReshapeOperator( |
2217 | model, static_cast<const TensorFlowReshapeOperator&>(src_op), |
2218 | tensorflow_graph); |
2219 | } else if (src_op.type == OperatorType::kL2Pool) { |
2220 | ConvertL2PoolOperator(static_cast<const L2PoolOperator&>(src_op), |
2221 | tensorflow_graph); |
2222 | } else if (src_op.type == OperatorType::kSquare) { |
2223 | ConvertSquareOperator(static_cast<const TensorFlowSquareOperator&>(src_op), |
2224 | tensorflow_graph); |
2225 | } else if (src_op.type == OperatorType::kSqrt) { |
2226 | ConvertSqrtOperator(static_cast<const TensorFlowSqrtOperator&>(src_op), |
2227 | tensorflow_graph); |
2228 | } else if (src_op.type == OperatorType::kRsqrt) { |
2229 | ConvertRsqrtOperator(model, |
2230 | static_cast<const TensorFlowRsqrtOperator&>(src_op), |
2231 | tensorflow_graph); |
2232 | } else if (src_op.type == OperatorType::kSplit) { |
2233 | ConvertSplitOperator(model, |
2234 | static_cast<const TensorFlowSplitOperator&>(src_op), |
2235 | tensorflow_graph); |
2236 | } else if (src_op.type == OperatorType::kSplitV) { |
2237 | ConvertSplitVOperator(model, |
2238 | static_cast<const TensorFlowSplitVOperator&>(src_op), |
2239 | tensorflow_graph); |
2240 | } else if (src_op.type == OperatorType::kFakeQuant) { |
2241 | ConvertFakeQuantOperator(static_cast<const FakeQuantOperator&>(src_op), |
2242 | tensorflow_graph); |
2243 | } else if (src_op.type == OperatorType::kCast) { |
2244 | ConvertCastOperator(model, static_cast<const CastOperator&>(src_op), |
2245 | tensorflow_graph); |
2246 | } else if (src_op.type == OperatorType::kFloor) { |
2247 | ConvertFloorOperator(model, static_cast<const FloorOperator&>(src_op), |
2248 | tensorflow_graph); |
2249 | } else if (src_op.type == OperatorType::kCeil) { |
2250 | ConvertCeilOperator(model, static_cast<const CeilOperator&>(src_op), |
2251 | tensorflow_graph); |
2252 | } else if (src_op.type == OperatorType::kRound) { |
2253 | ConvertRoundOperator(model, static_cast<const RoundOperator&>(src_op), |
2254 | tensorflow_graph); |
2255 | } else if (src_op.type == OperatorType::kGather) { |
2256 | ConvertGatherOperator(model, static_cast<const GatherOperator&>(src_op), |
2257 | tensorflow_graph); |
2258 | } else if (src_op.type == OperatorType::kResizeBilinear) { |
2259 | ConvertResizeBilinearOperator( |
2260 | model, static_cast<const ResizeBilinearOperator&>(src_op), |
2261 | tensorflow_graph); |
2262 | } else if (src_op.type == OperatorType::kResizeNearestNeighbor) { |
2263 | ConvertResizeNearestNeighborOperator( |
2264 | model, static_cast<const ResizeNearestNeighborOperator&>(src_op), |
2265 | tensorflow_graph); |
2266 | } else if (src_op.type == OperatorType::kSpaceToBatchND) { |
2267 | ConvertSpaceToBatchNDOperator( |
2268 | model, static_cast<const SpaceToBatchNDOperator&>(src_op), |
2269 | tensorflow_graph); |
2270 | } else if (src_op.type == OperatorType::kBatchToSpaceND) { |
2271 | ConvertBatchToSpaceNDOperator( |
2272 | model, static_cast<const BatchToSpaceNDOperator&>(src_op), |
2273 | tensorflow_graph); |
2274 | } else if (src_op.type == OperatorType::kPad) { |
2275 | ConvertPadOperator(model, static_cast<const PadOperator&>(src_op), |
2276 | tensorflow_graph); |
2277 | } else if (src_op.type == OperatorType::kPadV2) { |
2278 | ConvertPadV2Operator(model, static_cast<const PadV2Operator&>(src_op), |
2279 | tensorflow_graph); |
2280 | } else if (src_op.type == OperatorType::kStridedSlice) { |
2281 | ConvertStridedSliceOperator( |
2282 | model, static_cast<const StridedSliceOperator&>(src_op), |
2283 | tensorflow_graph); |
2284 | } else if (src_op.type == OperatorType::kMean) { |
2285 | ConvertReduceOperator(model, static_cast<const MeanOperator&>(src_op), |
2286 | tensorflow_graph, "Mean" ); |
2287 | } else if (src_op.type == OperatorType::kSum) { |
2288 | ConvertReduceOperator(model, |
2289 | static_cast<const TensorFlowSumOperator&>(src_op), |
2290 | tensorflow_graph, "Sum" ); |
2291 | } else if (src_op.type == OperatorType::kReduceProd) { |
2292 | ConvertReduceOperator(model, |
2293 | static_cast<const TensorFlowProdOperator&>(src_op), |
2294 | tensorflow_graph, "Prod" ); |
2295 | } else if (src_op.type == OperatorType::kReduceMin) { |
2296 | ConvertReduceOperator(model, |
2297 | static_cast<const TensorFlowMinOperator&>(src_op), |
2298 | tensorflow_graph, "Min" ); |
2299 | } else if (src_op.type == OperatorType::kReduceMax) { |
2300 | ConvertReduceOperator(model, |
2301 | static_cast<const TensorFlowMaxOperator&>(src_op), |
2302 | tensorflow_graph, "Max" ); |
2303 | } else if (src_op.type == OperatorType::kSub) { |
2304 | ConvertSubOperator(model, static_cast<const SubOperator&>(src_op), |
2305 | tensorflow_graph); |
2306 | } else if (src_op.type == OperatorType::kMinimum) { |
2307 | ConvertTensorFlowMinimumOperator( |
2308 | model, static_cast<const TensorFlowMinimumOperator&>(src_op), |
2309 | tensorflow_graph); |
2310 | } else if (src_op.type == OperatorType::kMaximum) { |
2311 | ConvertTensorFlowMaximumOperator( |
2312 | model, static_cast<const TensorFlowMaximumOperator&>(src_op), |
2313 | tensorflow_graph); |
2314 | } else if (src_op.type == OperatorType::kSqueeze) { |
2315 | ConvertSqueezeOperator(model, static_cast<const SqueezeOperator&>(src_op), |
2316 | tensorflow_graph); |
2317 | } else if (src_op.type == OperatorType::kSlice) { |
2318 | ConvertSliceOperator(model, static_cast<const SliceOperator&>(src_op), |
2319 | tensorflow_graph); |
2320 | } else if (src_op.type == OperatorType::kArgMax) { |
2321 | ConvertArgMaxOperator(model, static_cast<const ArgMaxOperator&>(src_op), |
2322 | tensorflow_graph); |
2323 | } else if (src_op.type == OperatorType::kArgMin) { |
2324 | ConvertArgMinOperator(model, static_cast<const ArgMinOperator&>(src_op), |
2325 | tensorflow_graph); |
2326 | } else if (src_op.type == OperatorType::kTopK_V2) { |
2327 | ConvertTopKV2Operator(model, static_cast<const TopKV2Operator&>(src_op), |
2328 | tensorflow_graph); |
2329 | } else if (src_op.type == OperatorType::kTranspose) { |
2330 | ConvertTransposeOperator( |
2331 | model, static_cast<const TransposeOperator&>(src_op), tensorflow_graph); |
2332 | } else if (src_op.type == OperatorType::kShape) { |
2333 | ConvertTensorFlowShapeOperator( |
2334 | model, static_cast<const TensorFlowShapeOperator&>(src_op), |
2335 | tensorflow_graph); |
2336 | } else if (src_op.type == OperatorType::kRank) { |
2337 | ConvertRankOperator(model, |
2338 | static_cast<const TensorFlowRankOperator&>(src_op), |
2339 | tensorflow_graph); |
2340 | } else if (src_op.type == OperatorType::kRange) { |
2341 | ConvertRangeOperator(model, static_cast<const RangeOperator&>(src_op), |
2342 | tensorflow_graph); |
2343 | } else if (src_op.type == OperatorType::kPack) { |
2344 | ConvertPackOperator(model, static_cast<const PackOperator&>(src_op), |
2345 | tensorflow_graph); |
2346 | } else if (src_op.type == OperatorType::kFill) { |
2347 | ConvertFillOperator(model, static_cast<const FillOperator&>(src_op), |
2348 | tensorflow_graph); |
2349 | } else if (src_op.type == OperatorType::kFloorDiv) { |
2350 | ConvertFloorDivOperator(model, static_cast<const FloorDivOperator&>(src_op), |
2351 | tensorflow_graph); |
2352 | } else if (src_op.type == OperatorType::kFloorMod) { |
2353 | ConvertFloorModOperator(model, static_cast<const FloorModOperator&>(src_op), |
2354 | tensorflow_graph); |
2355 | } else if (src_op.type == OperatorType::kExpandDims) { |
2356 | ConvertExpandDimsOperator(model, |
2357 | static_cast<const ExpandDimsOperator&>(src_op), |
2358 | tensorflow_graph); |
2359 | } else if (src_op.type == OperatorType::kTransposeConv) { |
2360 | ConvertTransposeConvOperator( |
2361 | model, static_cast<const TransposeConvOperator&>(src_op), |
2362 | tensorflow_graph); |
2363 | } else if (src_op.type == OperatorType::kRandomUniform) { |
2364 | ConvertRandomUniformOperator( |
2365 | model, static_cast<const RandomUniformOperator&>(src_op), |
2366 | tensorflow_graph); |
2367 | } else if (src_op.type == OperatorType::kEqual) { |
2368 | ConvertComparisonOperator(model, src_op, "Equal" , tensorflow_graph); |
2369 | } else if (src_op.type == OperatorType::kNotEqual) { |
2370 | ConvertComparisonOperator(model, src_op, "NotEqual" , tensorflow_graph); |
2371 | } else if (src_op.type == OperatorType::kGreater) { |
2372 | ConvertComparisonOperator(model, src_op, "Greater" , tensorflow_graph); |
2373 | } else if (src_op.type == OperatorType::kGreaterEqual) { |
2374 | ConvertComparisonOperator(model, src_op, "GreaterEqual" , tensorflow_graph); |
2375 | } else if (src_op.type == OperatorType::kLess) { |
2376 | ConvertComparisonOperator(model, src_op, "Less" , tensorflow_graph); |
2377 | } else if (src_op.type == OperatorType::kLessEqual) { |
2378 | ConvertComparisonOperator(model, src_op, "LessEqual" , tensorflow_graph); |
2379 | } else if (src_op.type == OperatorType::kSelect) { |
2380 | ConvertSelectOperator(model, static_cast<const SelectOperator&>(src_op), |
2381 | tensorflow_graph); |
2382 | } else if (src_op.type == OperatorType::kTile) { |
2383 | ConvertTileOperator(model, |
2384 | static_cast<const TensorFlowTileOperator&>(src_op), |
2385 | tensorflow_graph); |
2386 | } else if (src_op.type == OperatorType::kPow) { |
2387 | ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow" , |
2388 | tensorflow_graph); |
2389 | } else if (src_op.type == OperatorType::kAny) { |
2390 | ConvertReduceOperator(model, |
2391 | static_cast<const TensorFlowAnyOperator&>(src_op), |
2392 | tensorflow_graph, "Any" ); |
2393 | } else if (src_op.type == OperatorType::kLogicalAnd) { |
2394 | ConvertLogicalAndOperator(model, |
2395 | static_cast<const LogicalAndOperator&>(src_op), |
2396 | tensorflow_graph); |
2397 | } else if (src_op.type == OperatorType::kLogicalNot) { |
2398 | ConvertLogicalNotOperator(model, |
2399 | static_cast<const LogicalNotOperator&>(src_op), |
2400 | tensorflow_graph); |
2401 | } else if (src_op.type == OperatorType::kOneHot) { |
2402 | ConvertOneHotOperator(model, static_cast<const OneHotOperator&>(src_op), |
2403 | tensorflow_graph); |
2404 | } else if (src_op.type == OperatorType::kLogicalOr) { |
2405 | ConvertLogicalOrOperator(model, |
2406 | static_cast<const LogicalOrOperator&>(src_op), |
2407 | "LogicalOr" , tensorflow_graph); |
2408 | } else if (src_op.type == OperatorType::kCTCBeamSearchDecoder) { |
2409 | ConvertCTCBeamSearchDecoderOperator( |
2410 | model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op), |
2411 | "CTCBeamSearchDecoder" , tensorflow_graph); |
2412 | } else if (src_op.type == OperatorType::kUnpack) { |
2413 | ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op), |
2414 | "Unpack" , tensorflow_graph); |
2415 | } else if (src_op.type == OperatorType::kZerosLike) { |
2416 | ConvertZerosLikeOperator( |
2417 | model, static_cast<const TensorFlowZerosLikeOperator&>(src_op), |
2418 | "ZerosLike" , tensorflow_graph); |
2419 | } else if (src_op.type == OperatorType::kReverseV2) { |
2420 | ConvertReverseV2Operator(model, |
2421 | static_cast<const ReverseV2Operator&>(src_op), |
2422 | "Reverse_V2" , tensorflow_graph); |
2423 | } else if (src_op.type == OperatorType::kReverseSequence) { |
2424 | ConvertReverseSequenceOperator( |
2425 | model, static_cast<const ReverseSequenceOperator&>(src_op), |
2426 | tensorflow_graph); |
2427 | } else { |
2428 | LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); |
2429 | } |
2430 | } |
2431 | |
2432 | void AddPlaceholder(const std::string& name, ArrayDataType type, |
2433 | GraphDef* tensorflow_graph) { |
2434 | tensorflow::NodeDef* placeholder = tensorflow_graph->add_node(); |
2435 | placeholder->set_op("Placeholder" ); |
2436 | switch (type) { |
2437 | case ArrayDataType::kBool: |
2438 | (*placeholder->mutable_attr())["dtype" ].set_type(DT_BOOL); |
2439 | break; |
2440 | case ArrayDataType::kFloat: |
2441 | (*placeholder->mutable_attr())["dtype" ].set_type(DT_FLOAT); |
2442 | break; |
2443 | case ArrayDataType::kUint8: |
2444 | (*placeholder->mutable_attr())["dtype" ].set_type(DT_UINT8); |
2445 | break; |
2446 | case ArrayDataType::kInt32: |
2447 | (*placeholder->mutable_attr())["dtype" ].set_type(DT_INT32); |
2448 | break; |
2449 | case ArrayDataType::kUint32: |
2450 | (*placeholder->mutable_attr())["dtype" ].set_type(DT_UINT32); |
2451 | break; |
2452 | case ArrayDataType::kInt64: |
2453 | (*placeholder->mutable_attr())["dtype" ].set_type(DT_INT64); |
2454 | break; |
2455 | case ArrayDataType::kInt16: |
2456 | (*placeholder->mutable_attr())["dtype" ].set_type(DT_INT16); |
2457 | break; |
2458 | case ArrayDataType::kComplex64: |
2459 | (*placeholder->mutable_attr())["dtype" ].set_type(DT_COMPLEX64); |
2460 | break; |
2461 | default: |
2462 | LOG(FATAL) << "Unexpected data type in array \"" << name << "\"" ; |
2463 | } |
2464 | placeholder->set_name(name); |
2465 | } |
2466 | |
2467 | void AddPlaceholderForRNNState(const Model& model, const std::string& name, |
2468 | int size, GraphDef* tensorflow_graph) { |
2469 | tensorflow::NodeDef* placeholder = tensorflow_graph->add_node(); |
2470 | placeholder->set_op("Placeholder" ); |
2471 | placeholder->set_name(name); |
2472 | (*placeholder->mutable_attr())["dtype" ].set_type(DT_FLOAT); |
2473 | |
2474 | auto* shape = (*placeholder->mutable_attr())["shape" ].mutable_shape(); |
2475 | const auto& state_array = model.GetArray(name); |
2476 | if (state_array.has_shape()) { |
2477 | const auto& state_shape = state_array.shape(); |
2478 | const int kDims = state_shape.dimensions_count(); |
2479 | for (int i = 0; i < kDims; ++i) { |
2480 | shape->add_dim()->set_size(state_shape.dims(i)); |
2481 | } |
2482 | } else { |
2483 | shape->add_dim()->set_size(1); |
2484 | shape->add_dim()->set_size(size); |
2485 | } |
2486 | } |
2487 | |
2488 | void ExportTensorFlowGraphDefImplementation(const Model& model, |
2489 | GraphDef* tensorflow_graph) { |
2490 | for (const auto& input_array : model.flags.input_arrays()) { |
2491 | AddPlaceholder(input_array.name(), |
2492 | model.GetArray(input_array.name()).data_type, |
2493 | tensorflow_graph); |
2494 | } |
2495 | for (const auto& rnn_state : model.flags.rnn_states()) { |
2496 | AddPlaceholderForRNNState(model, rnn_state.state_array(), rnn_state.size(), |
2497 | tensorflow_graph); |
2498 | } |
2499 | for (const auto& op : model.operators) { |
2500 | ConvertOperator(model, *op, tensorflow_graph); |
2501 | } |
2502 | // Generically export arrays that haven't been exported already |
2503 | // by the above operators export. It's important that this comes |
2504 | // after, as some operators need to export arrays that they reference |
2505 | // in a specific way, rather than in the generic way done below. |
2506 | for (const auto& array_pair : model.GetArrayMap()) { |
2507 | const std::string& array_name = array_pair.first; |
2508 | const auto& array = *array_pair.second; |
2509 | if (array.buffer) { |
2510 | switch (array.data_type) { |
2511 | case ArrayDataType::kBool: |
2512 | ConvertBoolTensorConst(model, array_name, tensorflow_graph); |
2513 | break; |
2514 | case ArrayDataType::kFloat: |
2515 | ConvertFloatTensorConst(model, array_name, tensorflow_graph); |
2516 | break; |
2517 | case ArrayDataType::kInt32: |
2518 | ConvertIntTensorConst(model, array_name, tensorflow_graph); |
2519 | break; |
2520 | case ArrayDataType::kComplex64: |
2521 | ConvertComplex64TensorConst(model, array_name, tensorflow_graph); |
2522 | break; |
2523 | default: |
2524 | break; |
2525 | } |
2526 | } |
2527 | } |
2528 | } |
2529 | } // namespace |
2530 | |
2531 | void EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(Model* model) { |
2532 | for (const auto& array_kv : model->GetArrayMap()) { |
2533 | const std::string& array_name = array_kv.first; |
2534 | Array& array = *array_kv.second; |
2535 | if (!array.buffer || !array.minmax) { |
2536 | continue; |
2537 | } |
2538 | const std::string& wrapped_array_name = |
2539 | AvailableArrayName(*model, array_name + "/data" ); |
2540 | Array& wrapped_array = model->GetOrCreateArray(wrapped_array_name); |
2541 | wrapped_array.data_type = array.data_type; |
2542 | wrapped_array.copy_shape(array.shape()); |
2543 | wrapped_array.buffer = std::move(array.buffer); |
2544 | FakeQuantOperator* fakequant_op = new FakeQuantOperator; |
2545 | fakequant_op->inputs = {wrapped_array_name}; |
2546 | fakequant_op->outputs = {array_name}; |
2547 | fakequant_op->minmax = std::make_unique<MinMax>(); |
2548 | *fakequant_op->minmax = *array.minmax; |
2549 | const auto& it = FindOpWithInput(*model, array_name); |
2550 | model->operators.emplace(it, fakequant_op); |
2551 | } |
2552 | CheckInvariants(*model); |
2553 | } |
2554 | |
2555 | void ExportTensorFlowGraphDef(const Model& model, |
2556 | std::string* output_file_contents) { |
2557 | CHECK(output_file_contents->empty()); |
2558 | GraphDef tensorflow_graph; |
2559 | ExportTensorFlowGraphDefImplementation(model, &tensorflow_graph); |
2560 | LogDumpGraphDef(kLogLevelModelChanged, "AT EXPORT" , tensorflow_graph); |
2561 | CHECK(tensorflow_graph.SerializeToString(output_file_contents)); |
2562 | } |
2563 | } // namespace toco |
2564 | |