1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/toco/toco_tooling.h" |
16 | |
17 | #include <cstdlib> |
18 | #include <memory> |
19 | #include <set> |
20 | #include <string> |
21 | |
22 | #include "absl/memory/memory.h" |
23 | #include "absl/strings/str_join.h" |
24 | #include "tensorflow/core/platform/logging.h" |
25 | #include "tensorflow/lite/toco/allocate_transient_arrays.h" |
26 | #include "tensorflow/lite/toco/dump_graphviz.h" |
27 | #include "tensorflow/lite/toco/export_tensorflow.h" |
28 | #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" |
29 | #include "tensorflow/lite/toco/import_tensorflow.h" |
30 | #include "tensorflow/lite/toco/model.h" |
31 | #include "tensorflow/lite/toco/model_flags.pb.h" |
32 | #include "tensorflow/lite/toco/tflite/export.h" |
33 | #include "tensorflow/lite/toco/tflite/import.h" |
34 | #include "tensorflow/lite/toco/toco_flags.pb.h" |
35 | #include "tensorflow/lite/toco/tooling_util.h" |
36 | |
37 | namespace toco { |
38 | namespace { |
39 | // CHECK-fails if the model contains a kUnsupported operation. |
40 | void CheckUnsupportedOperations(const Model& model) { |
41 | std::set<std::string> unsupported_ops; |
42 | for (auto& op : model.operators) { |
43 | if (op->type == OperatorType::kUnsupported) { |
44 | unsupported_ops.insert( |
45 | static_cast<const TensorFlowUnsupportedOperator*>(op.get()) |
46 | ->tensorflow_op); |
47 | } |
48 | } |
49 | QCHECK(unsupported_ops.empty()) |
50 | << "These unsupported ops were not removed by graph transformations: " |
51 | << absl::StrJoin(unsupported_ops, ", " ); |
52 | } |
53 | |
// Populates |transformations| with the graph transformations that
// TransformWithStatus runs for every output format. Format- and
// flag-dependent transformations are appended by the caller afterwards.
//
// |transformations| must be empty on entry (CHECK-enforced). Each
// transformation is heap-allocated and handed to the set via Add(); the set
// presumably takes ownership (NOTE(review): confirm in
// graph_transformations.h). The insertion order below is preserved as
// written. Do not reorder casually, since the order may affect which
// transformation gets to apply first.
void MakeGeneralGraphTransformationsSet(
    GraphTransformationsSet* transformations) {
  CHECK(transformations->empty());
  transformations->Add(new ConvertExpandDimsToReshape);
  transformations->Add(new ConvertMatrixDiagV2OrV3ToV1);
  transformations->Add(new ConvertMatrixSetDiagV2OrV3ToV1);
  transformations->Add(new ConvertSqueezeToReshape);
  transformations->Add(new ConvertTrivialAddNToAdd);
  transformations->Add(new ConvertTrivialPackToReshape);
  transformations->Add(new ConvertTrivialTileToConcat);
  transformations->Add(new ConvertTrivialTransposeToReshape);
  transformations->Add(new ConvertReorderAxes);
  transformations->Add(new ResolveReshapeAttributes);
  transformations->Add(new ResolveTransposeAttributes);
  transformations->Add(new PropagateActivationFunctionIntoConstants);
  transformations->Add(new PropagateArrayDataTypes);
  transformations->Add(new PropagateFixedSizes);
  transformations->Add(new RemoveSuccessiveTranspose);
  transformations->Add(new RemoveTensorFlowAssert);
  transformations->Add(new RemoveTensorFlowIdentity);
  transformations->Add(new RemoveTrivialConcatenation);
  transformations->Add(new RemoveTrivialConcatenationInput);
  transformations->Add(new RemoveTrivialFakeQuant);
  transformations->Add(new RemoveTrivialSlice);
  transformations->Add(new RemoveUnusedOp);
  transformations->Add(new EnsureBiasVectors);
  transformations->Add(new ResolveReorderAxes);
  transformations->Add(new UnrollBatchMatMul);
  transformations->Add(new ResolveTensorFlowMatMul);
  transformations->Add(new FuseBinaryIntoPrecedingAffine);
  transformations->Add(new FuseBinaryIntoFollowingAffine);
  transformations->Add(new FuseBroadcastIntoFollowingBinary);
  transformations->Add(new MergeReshapeIntoPrecedingTranspose);
  transformations->Add(new MoveBinaryOperatorBeforeReshape);
  transformations->Add(new ReorderElementwiseUnary);
  transformations->Add(new ReorderReshapeTranspose);
  transformations->Add(new ResolveBatchNormalization);
  transformations->Add(new ResolveConstantBinaryOperator);
  transformations->Add(new ResolveConstantFill);
  transformations->Add(new ResolveConstantGather);
  transformations->Add(new ResolveConstantPack);
  transformations->Add(new ResolveConstantRandomUniform);
  transformations->Add(new ResolveConstantRange);
  transformations->Add(new ResolveConstantReshape);
  transformations->Add(new ResolveConstantSelect);
  transformations->Add(new ResolveConstantSlice);
  transformations->Add(new ResolveConstantStridedSlice);
  transformations->Add(new ResolveConstantTile);
  transformations->Add(new ResolveConstantTranspose);
  transformations->Add(new ResolveConstantUnaryOperator);
  transformations->Add(new ResolveTensorFlowMerge);
  transformations->Add(new ResolveSqueezeAttributes);
  transformations->Add(new ResolveTensorFlowSwitch);
  transformations->Add(new ResolveTensorFlowConcat);
  transformations->Add(new ResolveMultiplyByZero);
  transformations->Add(new IdentifyHardSwish);
  transformations->Add(new IdentifyL2Normalization);
  transformations->Add(new IdentifyL2Pool);
  transformations->Add(new IdentifyRelu1);
  transformations->Add(new IdentifyPRelu);
  transformations->Add(new RemoveTrivialBinaryOperator);
  transformations->Add(new ResolveFakeQuantArgsFromVars);
  transformations->Add(new ReadArrayMinmaxAndNarrowRangeFromFakeQuant);
  transformations->Add(new ResolveSpaceToBatchNDAttributes);
  transformations->Add(new ResolveBatchToSpaceNDAttributes);
  transformations->Add(new ResolvePadAttributes);
  transformations->Add(new ResolvePadV2Attributes);
  transformations->Add(new ResolveStridedSliceAttributes);
  transformations->Add(new ResolveSliceAttributes);
  transformations->Add(new ResolveReduceAttributes);
  transformations->Add(new ResolveConstantShapeOrRank);
  transformations->Add(new MakeInitialDequantizeOperator);
  transformations->Add(new UnpartitionEmbeddingLookup);
  transformations->Add(new ResolveGatherAttributes);
}
129 | |
130 | bool SupportsQuantization(FileFormat format) { |
131 | return (format == GRAPHVIZ_DOT || format == TFLITE); |
132 | } |
133 | |
134 | bool SupportsFusedActivationFunction(FileFormat format) { |
135 | return (format == GRAPHVIZ_DOT || format == TFLITE); |
136 | } |
137 | |
138 | bool SupportsLstmCell(FileFormat format) { |
139 | return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT || |
140 | format == TFLITE); |
141 | } |
142 | |
143 | bool SupportsPreallocatedWorkspace(FileFormat format) { |
144 | return (format == TFLITE); |
145 | } |
146 | |
147 | bool SupportsShuffledFCWeights(FileFormat format) { return format == TFLITE; } |
148 | |
149 | bool IsRealValued(toco::ArrayDataType type) { |
150 | // TODO(benoitjacob) - this is hardcoding that uint8 and int16 are only used |
151 | // for quantized real-number values, and no other integer type is ever used |
152 | // for that. This is dirty, should be resolved as part of a more general push |
153 | // to more explicitly distinguish between true-integers and |
154 | // integers used as quantized values representing real numbers. |
155 | return static_cast<bool>(type == toco::ArrayDataType::kFloat || |
156 | type == toco::ArrayDataType::kUint8 || |
157 | type == toco::ArrayDataType::kInt16); |
158 | } |
159 | |
160 | void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) { |
161 | const FileFormat output_format = toco_flags.output_format(); |
162 | ArrayDataType type; |
163 | if (!SupportsQuantization(output_format)) { |
164 | // Data type is implicitly float for non-quantized formats |
165 | type = ArrayDataType::kFloat; |
166 | } else if (toco_flags.has_inference_input_type()) { |
167 | type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type()); |
168 | } else if (toco_flags.has_inference_type()) { |
169 | type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type()); |
170 | } else { |
171 | // Nothing to do. Data types stay as-is. |
172 | return; |
173 | } |
174 | |
175 | for (int i = 0; i < model->flags.input_arrays_size(); i++) { |
176 | std::string const& array_name = model->flags.input_arrays(i).name(); |
177 | auto* array = &model->GetArray(array_name); |
178 | // Note that the notion of changing data types only applies to real-numbers |
179 | // arrays (see the documentation for inference_input_type). |
180 | // TODO(benoitjacob) this is assuming that uint8 arrays are quantized, |
181 | // i.e. represent real numbers by means of quantization parameters, |
182 | // and not plain integer uint8 input arrays. |
183 | if (!IsRealValued(array->data_type)) { |
184 | // Ignore non-real data types. |
185 | continue; |
186 | } |
187 | // The enum value QUANTIZED_UINT8 for --inference_type and |
188 | // --inference_input_type has long meant just 'QUANTIZED', being used as |
189 | // well in mixed 8-bit / 16-bit quantized models. However, |
190 | // ConvertIODataTypeToArrayDataType still interpretes it as meaning 8bit, |
191 | // and people have run into issues in the situation where they have an |
192 | // already mixed 8-bit / 16-bit quantized model in TFLITE format and |
193 | // want to run it again through toco, without having to re-specify all the |
194 | // extra array info that was used in the (complicated) process of initially |
195 | // quantizing that model. In order to have --inference_type=QUANTIZED_UINT8 |
196 | // just work in that case, we implement the logic that when an array is |
197 | // already quantized, if --inference_type is quantized (so we're not |
198 | // asking to dequantize here), no change of quantized data type is to be |
199 | // recorded. |
200 | if (array->data_type != toco::ArrayDataType::kFloat && |
201 | type != toco::ArrayDataType::kFloat) { |
202 | continue; |
203 | } |
204 | |
205 | array->final_data_type = type; |
206 | } |
207 | } |
208 | |
209 | } // namespace |
210 | |
211 | std::unique_ptr<Model> Import(const TocoFlags& toco_flags, |
212 | const ModelFlags& model_flags, |
213 | const std::string& input_file_contents) { |
214 | std::unique_ptr<Model> model; |
215 | switch (toco_flags.input_format()) { |
216 | case TENSORFLOW_GRAPHDEF: { |
217 | TensorFlowImportFlags tf_import_flags; |
218 | tf_import_flags.drop_control_dependency = |
219 | toco_flags.has_drop_control_dependency() |
220 | ? toco_flags.drop_control_dependency() |
221 | : (toco_flags.output_format() != TENSORFLOW_GRAPHDEF); |
222 | |
223 | tf_import_flags.import_all_ops_as_unsupported = |
224 | toco_flags.force_select_tf_ops(); |
225 | |
226 | model = ImportTensorFlowGraphDef(model_flags, tf_import_flags, |
227 | input_file_contents); |
228 | break; |
229 | } |
230 | case TFLITE: |
231 | model = toco::tflite::Import(model_flags, input_file_contents); |
232 | ResolveModelFlags(model_flags, model.get()); |
233 | CheckInvariants(*model); |
234 | break; |
235 | default: |
236 | LOG(FATAL) << "Unhandled input_format='" |
237 | << FileFormat_Name(toco_flags.input_format()) << "'" ; |
238 | } |
239 | |
240 | LogDump(kLogLevelModelChanged, "AT IMPORT" , *model); |
241 | |
242 | return model; |
243 | } |
244 | |
// Runs toco's graph-transformation pipeline on |model| in place, so that it
// can be exported in toco_flags.output_format().
//
// Phases, in the order they run:
//   1. Post-import cleanup (final data types on inputs, arrays extra info,
//      RNN states), then a pass consisting only of RemoveUnusedOp.
//   2. The general transformation set (MakeGeneralGraphTransformationsSet)
//      plus transformations selected by output format and flags.
//   3. If quantizing: optional fake-quant num-bits propagation, then
//      HardcodeMinMax + DropFakeQuant.
//   4. Grouping of bidirectional sequence LSTM/RNN ops, then IO edge and
//      operator-ordering fix-ups.
//   5. Quantization (with optional default min/max ranges) or, when not
//      quantizing, dequantization.
//   6. Nearest-upsample identification, min/max encoding of constant arrays
//      for GraphDef output, constant-array deduplication and final checks.
//   7. Transient-array allocation (formats with a preallocated workspace) and
//      logging of the estimated arithmetic-op and parameter counts. The op
//      count is stored in model->ops_count.
//
// Returns the first non-OK status produced by a transformation pass, and OK
// otherwise. Some checks (e.g. the QCHECK on inference_input_type below)
// abort the process instead of returning an error.
tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
                                       Model* model) {
  const FileFormat output_format = toco_flags.output_format();
  const IODataType inference_type = toco_flags.inference_type();

  // Quantized output is produced only when the output format supports it AND
  // a quantized inference type was requested.
  const bool quantize_output =
      SupportsQuantization(output_format) &&
      (inference_type == QUANTIZED_UINT8 || inference_type == QUANTIZED_INT16);

  if (quantize_output) {
    QCHECK_NE(toco_flags.inference_input_type(), FLOAT)
        << "Quantized inference is not allowed with float inputs.";
  }

  // Clean up after import.
  SetFinalDataTypeOnInputs(toco_flags, model);
  UseArraysExtraInfo(model, quantize_output);
  FinishBuildingRNNStates(model);

  // Remove unused ops before performing any other optimizations. This stops
  // optimizations from crossing the input/output boundaries. For example, it
  // prevents BatchNorm fusing when the output node sits between a conv and a
  // BatchNorm layer.
  TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
      model, "Removing unused ops", {new toco::RemoveUnusedOp}));

  GraphTransformationsSet transformations;
  MakeGeneralGraphTransformationsSet(&transformations);
  auto* remove_trivial_reshape = new RemoveTrivialReshape;
  transformations.Add(remove_trivial_reshape);
  // FakeQuant num-bits propagation is only configured when quantizing.
  auto* resolve_constant_fake_quant = new ResolveConstantFakeQuant;
  if (quantize_output) {
    resolve_constant_fake_quant->set_propagate_fake_quant_num_bits(
        toco_flags.propagate_fake_quant_num_bits());
  }
  transformations.Add(resolve_constant_fake_quant);
  // Keep activations fused only where the output format can represent that.
  if (SupportsFusedActivationFunction(output_format)) {
    transformations.Add(new FuseActivationFunctions);
  } else {
    transformations.Add(new UnfuseActivationFunctions);
  }
  if (toco_flags.drop_fake_quant()) {
    transformations.Add(new DropFakeQuant);
  } else {
    // See the doc for --reorder_across_fake_quant: that flag is needed to
    // support some existing models, e.g. WordLens, that have FakeQuant
    // nodes in the wrong places.
    // TODO(benoitjacob): drop special casing when we can.
    if ((quantize_output && toco_flags.reorder_across_fake_quant())) {
      transformations.Add(new DropFakeQuant);
    }
  }
  transformations.Add(new ConvertPureConvToDepthwise);
  if (SupportsLstmCell(output_format)) {
    if (!toco_flags.debug_disable_recurrent_cell_fusion()) {
      transformations.Add(new IdentifyLstmCell);
    }
    // TFLite output may opt into split LSTM cell inputs; every other case
    // gets merged inputs.
    if (output_format == TFLITE && toco_flags.split_tflite_lstm_inputs()) {
      transformations.Add(new toco::SplitLstmCellInputs);
    } else {
      transformations.Add(new toco::MergeLstmCellInputs);
    }
  }
  transformations.Add(new ResolveConstantConcatenation);
  // TODO(b/116063589): TF GraphDef doesn't support dilations on its depthwise
  // conv, so we need to make sure we don't convert to dilated depthwise conv
  // when outputting to TF GraphDef.
  auto* identify_dilated_conv = new IdentifyDilatedConv;
  if (output_format == TENSORFLOW_GRAPHDEF) {
    identify_dilated_conv->set_identify_depthwise_conv(false);
  }
  transformations.Add(identify_dilated_conv);
  TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
      model, "general graph transformations", transformations));

  if (quantize_output) {
    if (toco_flags.propagate_fake_quant_num_bits()) {
      TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
          model, "fake quant propagation graph transformations",
          {new PropagateFakeQuantNumBits}));
    }
    TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
        model, "pre-quantization graph transformations",
        {
            new HardcodeMinMax,
            new DropFakeQuant,
        }));
  }

  // Try to merge bidirectional sequence lstm or rnn if present.
  GraphTransformationsSet bidirectional_transformations;
  bidirectional_transformations.Add(new RemoveUnusedOp);
  bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceLstm);
  bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceRnn);
  bidirectional_transformations.Add(
      new toco::GroupDynamicBidirectionalSequenceRnn);
  bidirectional_transformations.Add(
      new toco::GroupDynamicBidirectionalSequenceLstm);
  TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
      model, "Group bidirectional sequence lstm/rnn",
      bidirectional_transformations));

  // Fix any issues with IO edges. This must happen after any transform that
  // may modify the structure of the edges.
  FixEdgeArrays(model);
  FixOperatorOrdering(model);

  if (quantize_output) {
    // If the user specified default min/max ranges, we need to set them on
    // all arrays that neither had a min/max specified nor got one from
    // HardcodeMinMax or PropagateFakeQuantNumBits. This may require running
    // HardcodeMinMax again to move changes through the graph as we go.
    auto propagate_default_min_max = std::make_unique<PropagateDefaultMinMax>();
    // The uint8 default range is used only when both bounds were given.
    bool has_default_ranges_flag = (toco_flags.has_default_ranges_min() &&
                                    toco_flags.has_default_ranges_max());
    if (has_default_ranges_flag) {
      propagate_default_min_max->DefineTypeRange(
          ArrayDataType::kUint8, toco_flags.default_ranges_min(),
          toco_flags.default_ranges_max());
    }
    if (toco_flags.has_default_int16_ranges_min() &&
        toco_flags.has_default_int16_ranges_max()) {
      propagate_default_min_max->DefineTypeRange(
          ArrayDataType::kInt16, toco_flags.default_int16_ranges_min(),
          toco_flags.default_int16_ranges_max());
    }
    // Ownership of the propagator is released into the transformation set
    // only when it will actually run; otherwise the unique_ptr frees it.
    if (propagate_default_min_max->has_any_ranges_defined()) {
      TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
          model, "default min-max range propagation graph transformations",
          {
              propagate_default_min_max.release(),
              new HardcodeMinMax,
          }));
    }

    CheckIsReadyForQuantization(*model);
    auto* ensure_safe_for_int8_kernels =
        new EnsureUint8WeightsSafeForFastInt8Kernels;
    ensure_safe_for_int8_kernels->set_allow_nudging_weights(
        toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
    ensure_safe_for_int8_kernels->set_has_default_ranges_flag(
        has_default_ranges_flag);
    TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
        model, "quantization graph transformations",
        {
            new RemoveTrivialQuantizedActivationFunc,
            new RemoveTrivialQuantizedMinMax,
            new Quantize,
            new RemoveFinalDequantizeOp,
            ensure_safe_for_int8_kernels,
        }));
    if (SupportsShuffledFCWeights(output_format)) {
      TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
          model, "shuffling of FC weights", {new ShuffleFCWeights}));
    }
  } else {
    GraphTransformationsSet dequantization_transformations{new Dequantize};
    // Dequantize creates FakeQuant nodes. We may want to discard
    // those immediately.
    if (toco_flags.drop_fake_quant()) {
      dequantization_transformations.Add(new DropFakeQuant);
    }

    TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
        model, "dequantization graph transformations",
        dequantization_transformations));
  }

  // Unfortunately this graph transformation has to run here. A user may
  // implement nearest-neighbor upsampling with a broadcast mul, e.g.:
  //   Input [1, 20, 1, 20, 1, 64] * ones [1, 3, 1, 3, 1, 1]
  // If the input is quantized, the quantization parameters of the input and
  // the output will differ slightly (although the difference is very small).
  // Because this pattern is rewritten into a pack-based form, which requires
  // the quantization parameters to be exactly the same, we have to wait until
  // all quantization parameters have been resolved and propagated, and then
  // create our own.
  // We may need to revisit this logic later.
  GraphTransformationsSet nearest_upsample_transformations;
  nearest_upsample_transformations.Add(new toco::IdentifyNearestUpsample);
  TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
      model, "Identify nearest upsample.", nearest_upsample_transformations));

  if (output_format == TENSORFLOW_GRAPHDEF) {
    EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model);
  }

  // Deduplicate large constant arrays.
  DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes());

  LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);

  if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) {
    // By now there shouldn't be any unsupported ops when exporting to
    // TensorFlow GraphDef. (Graphviz and TFLite output skip this check.)
    CheckUnsupportedOperations(*model);
  }

  if (SupportsPreallocatedWorkspace(output_format)) {
    AllocateTransientArrays(model, kDefaultTransientDataAlignment);
    LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
  }

  CheckModelCounts(*model);
  CheckFinalDataTypesSatisfied(*model);

  // Estimate and log the number of arithmetic ops. The value is stored on the
  // model whether or not the estimate succeeded.
  int64_t ops_count = 0;
  if (EstimateArithmeticOpsCount(*model, &ops_count)) {
    LOG(INFO) << "Estimated count of arithmetic ops: " << ops_count
              << " ops, equivalently " << ops_count / 2 << " MACs";
  }
  model->ops_count = ops_count;
  // Compute and log the number of parameters: every array with a buffer
  // contributes RequiredBufferSizeForShape of its shape.
  int64_t params_count = 0;

  for (const auto& array_pair : model->GetArrayMap()) {
    const Array& array = *array_pair.second;
    if (!array.buffer) {
      // not a parameter array
      continue;
    }
    params_count += RequiredBufferSizeForShape(array.shape());
  }
  LOG(INFO) << "Number of parameters: " << params_count;
  return ::tensorflow::OkStatus();
}
474 | |
475 | tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model, |
476 | bool allow_custom_ops, |
477 | std::string* output_file_contents) { |
478 | switch (toco_flags.output_format()) { |
479 | case TENSORFLOW_GRAPHDEF: |
480 | ExportTensorFlowGraphDef(model, output_file_contents); |
481 | break; |
482 | case TFLITE: { |
483 | toco::tflite::ExportParams params; |
484 | |
485 | params.enable_select_tf_ops = |
486 | toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops(); |
487 | params.allow_custom_ops = allow_custom_ops; |
488 | params.allow_dynamic_tensors = toco_flags.allow_dynamic_tensors(); |
489 | params.disable_per_channel = |
490 | toco_flags.disable_per_channel_quantization(); |
491 | if (toco_flags.post_training_quantize()) { |
492 | if (toco_flags.quantize_to_float16()) { |
493 | params.quantize_weights = tflite::QuantizedBufferType::FLOAT16; |
494 | } else { |
495 | params.quantize_weights = tflite::QuantizedBufferType::INT8; |
496 | } |
497 | } |
498 | auto status = toco::tflite::Export(model, output_file_contents, params); |
499 | if (!status.ok()) { |
500 | LOG(ERROR) << status.error_message(); |
501 | } |
502 | return status; |
503 | } break; |
504 | case GRAPHVIZ_DOT: |
505 | DumpGraphviz(model, output_file_contents, "Computation Graph" ); |
506 | break; |
507 | default: |
508 | LOG(FATAL) << "Unhandled output_format='" |
509 | << FileFormat_Name(toco_flags.output_format()) << "'" ; |
510 | } |
511 | return tensorflow::Status(); |
512 | } |
513 | |
514 | } // namespace toco |
515 | |