/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/toco/toco_tooling.h"

#include <cstdlib>
#include <memory>
#include <set>
#include <string>

#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/toco/allocate_transient_arrays.h"
#include "tensorflow/lite/toco/dump_graphviz.h"
#include "tensorflow/lite/toco/export_tensorflow.h"
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/lite/toco/import_tensorflow.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/tflite/export.h"
#include "tensorflow/lite/toco/tflite/import.h"
#include "tensorflow/lite/toco/toco_flags.pb.h"
#include "tensorflow/lite/toco/tooling_util.h"

namespace toco {
namespace {
// CHECK-fails if the model contains a kUnsupported operation.
void CheckUnsupportedOperations(const Model& model) {
  std::set<std::string> unsupported_ops;
  for (auto& op : model.operators) {
    if (op->type == OperatorType::kUnsupported) {
      unsupported_ops.insert(
          static_cast<const TensorFlowUnsupportedOperator*>(op.get())
              ->tensorflow_op);
    }
  }
  QCHECK(unsupported_ops.empty())
      << "These unsupported ops were not removed by graph transformations: "
      << absl::StrJoin(unsupported_ops, ", ");
}

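// Populates |transformations| with the general-purpose set of graph
// transformations that is applied to every model regardless of output
// format. Note that the runner applies the set repeatedly until the graph
// stops changing, so the insertion order here does not prescribe an
// execution order.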
void MakeGeneralGraphTransformationsSet(
    GraphTransformationsSet* transformations) {
  CHECK(transformations->empty());
  transformations->Add(new ConvertExpandDimsToReshape);
  transformations->Add(new ConvertMatrixDiagV2OrV3ToV1);
  transformations->Add(new ConvertMatrixSetDiagV2OrV3ToV1);
  transformations->Add(new ConvertSqueezeToReshape);
  transformations->Add(new ConvertTrivialAddNToAdd);
  transformations->Add(new ConvertTrivialPackToReshape);
  transformations->Add(new ConvertTrivialTileToConcat);
  transformations->Add(new ConvertTrivialTransposeToReshape);
  transformations->Add(new ConvertReorderAxes);
  transformations->Add(new ResolveReshapeAttributes);
  transformations->Add(new ResolveTransposeAttributes);
  transformations->Add(new PropagateActivationFunctionIntoConstants);
  transformations->Add(new PropagateArrayDataTypes);
  transformations->Add(new PropagateFixedSizes);
  transformations->Add(new RemoveSuccessiveTranspose);
  transformations->Add(new RemoveTensorFlowAssert);
  transformations->Add(new RemoveTensorFlowIdentity);
  transformations->Add(new RemoveTrivialConcatenation);
  transformations->Add(new RemoveTrivialConcatenationInput);
  transformations->Add(new RemoveTrivialFakeQuant);
  transformations->Add(new RemoveTrivialSlice);
  transformations->Add(new RemoveUnusedOp);
  transformations->Add(new EnsureBiasVectors);
  transformations->Add(new ResolveReorderAxes);
  transformations->Add(new UnrollBatchMatMul);
  transformations->Add(new ResolveTensorFlowMatMul);
  transformations->Add(new FuseBinaryIntoPrecedingAffine);
  transformations->Add(new FuseBinaryIntoFollowingAffine);
  transformations->Add(new FuseBroadcastIntoFollowingBinary);
  transformations->Add(new MergeReshapeIntoPrecedingTranspose);
  transformations->Add(new MoveBinaryOperatorBeforeReshape);
  transformations->Add(new ReorderElementwiseUnary);
  transformations->Add(new ReorderReshapeTranspose);
  transformations->Add(new ResolveBatchNormalization);
  transformations->Add(new ResolveConstantBinaryOperator);
  transformations->Add(new ResolveConstantFill);
  transformations->Add(new ResolveConstantGather);
  transformations->Add(new ResolveConstantPack);
  transformations->Add(new ResolveConstantRandomUniform);
  transformations->Add(new ResolveConstantRange);
  transformations->Add(new ResolveConstantReshape);
  transformations->Add(new ResolveConstantSelect);
  transformations->Add(new ResolveConstantSlice);
  transformations->Add(new ResolveConstantStridedSlice);
  transformations->Add(new ResolveConstantTile);
  transformations->Add(new ResolveConstantTranspose);
  transformations->Add(new ResolveConstantUnaryOperator);
  transformations->Add(new ResolveTensorFlowMerge);
  transformations->Add(new ResolveSqueezeAttributes);
  transformations->Add(new ResolveTensorFlowSwitch);
  transformations->Add(new ResolveTensorFlowConcat);
  transformations->Add(new ResolveMultiplyByZero);
  transformations->Add(new IdentifyHardSwish);
  transformations->Add(new IdentifyL2Normalization);
  transformations->Add(new IdentifyL2Pool);
  transformations->Add(new IdentifyRelu1);
  transformations->Add(new IdentifyPRelu);
  transformations->Add(new RemoveTrivialBinaryOperator);
  transformations->Add(new ResolveFakeQuantArgsFromVars);
  transformations->Add(new ReadArrayMinmaxAndNarrowRangeFromFakeQuant);
  transformations->Add(new ResolveSpaceToBatchNDAttributes);
  transformations->Add(new ResolveBatchToSpaceNDAttributes);
  transformations->Add(new ResolvePadAttributes);
  transformations->Add(new ResolvePadV2Attributes);
  transformations->Add(new ResolveStridedSliceAttributes);
  transformations->Add(new ResolveSliceAttributes);
  transformations->Add(new ResolveReduceAttributes);
  transformations->Add(new ResolveConstantShapeOrRank);
  transformations->Add(new MakeInitialDequantizeOperator);
  transformations->Add(new UnpartitionEmbeddingLookup);
  transformations->Add(new ResolveGatherAttributes);
}

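// Per-format capability predicates. Each output FileFormat supports a
// different subset of features, and the pipeline in TransformWithStatus is
// specialized accordingly.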
bool SupportsQuantization(FileFormat format) {
  return (format == GRAPHVIZ_DOT || format == TFLITE);
}

bool SupportsFusedActivationFunction(FileFormat format) {
  return (format == GRAPHVIZ_DOT || format == TFLITE);
}

bool SupportsLstmCell(FileFormat format) {
  return (format == TENSORFLOW_GRAPHDEF || format == GRAPHVIZ_DOT ||
          format == TFLITE);
}

bool SupportsPreallocatedWorkspace(FileFormat format) {
  return (format == TFLITE);
}

bool SupportsShuffledFCWeights(FileFormat format) { return format == TFLITE; }

bool IsRealValued(toco::ArrayDataType type) {
  // TODO(benoitjacob): this hardcodes the assumption that uint8 and int16 are
  // only ever used for quantized real-number values, and that no other
  // integer type is ever used for that. This is dirty and should be resolved
  // as part of a more general push to more explicitly distinguish between
  // true integers and integers used as quantized values representing real
  // numbers.
  return type == toco::ArrayDataType::kFloat ||
         type == toco::ArrayDataType::kUint8 ||
         type == toco::ArrayDataType::kInt16;
}

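// Records the final data type on the model's real-valued input arrays, as
// implied by the output format and the --inference_input_type /
// --inference_type flags.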
void SetFinalDataTypeOnInputs(const TocoFlags& toco_flags, Model* model) {
  const FileFormat output_format = toco_flags.output_format();
  ArrayDataType type;
  if (!SupportsQuantization(output_format)) {
    // Data type is implicitly float for non-quantized formats.
    type = ArrayDataType::kFloat;
  } else if (toco_flags.has_inference_input_type()) {
    type = ConvertIODataTypeToArrayDataType(toco_flags.inference_input_type());
  } else if (toco_flags.has_inference_type()) {
    type = ConvertIODataTypeToArrayDataType(toco_flags.inference_type());
  } else {
    // Nothing to do. Data types stay as-is.
    return;
  }

  for (int i = 0; i < model->flags.input_arrays_size(); i++) {
    std::string const& array_name = model->flags.input_arrays(i).name();
    auto* array = &model->GetArray(array_name);
    // Note that the notion of changing data types only applies to real-number
    // arrays (see the documentation for inference_input_type).
    // TODO(benoitjacob): this assumes that uint8 arrays are quantized, i.e.
    // represent real numbers by means of quantization parameters, and are not
    // plain integer uint8 input arrays.
    if (!IsRealValued(array->data_type)) {
      // Ignore non-real data types.
      continue;
    }
    // The enum value QUANTIZED_UINT8 for --inference_type and
    // --inference_input_type has long meant just 'QUANTIZED', and is used as
    // such in mixed 8-bit / 16-bit quantized models. However,
    // ConvertIODataTypeToArrayDataType still interprets it as meaning 8-bit,
    // and people have run into issues in the situation where they have an
    // already mixed 8-bit / 16-bit quantized model in TFLITE format and
    // want to run it again through toco, without having to re-specify all the
    // extra array info that was used in the (complicated) process of initially
    // quantizing that model. In order to have --inference_type=QUANTIZED_UINT8
    // just work in that case, we implement the logic that when an array is
    // already quantized, and --inference_type is quantized (so we're not
    // asking to dequantize here), no change of quantized data type is
    // recorded.
    if (array->data_type != toco::ArrayDataType::kFloat &&
        type != toco::ArrayDataType::kFloat) {
      continue;
    }

    array->final_data_type = type;
  }
}

}  // namespace

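// Parses |input_file_contents|, in the format given by
// toco_flags.input_format(), into an in-memory Model.
//
// Typical tooling flow (a sketch; flag setup and error handling elided):
//   std::unique_ptr<Model> model =
//       Import(toco_flags, model_flags, input_file_contents);
//   TF_CHECK_OK(TransformWithStatus(toco_flags, model.get()));
//   std::string output_file_contents;
//   TF_CHECK_OK(Export(toco_flags, *model, /*allow_custom_ops=*/true,
//                      &output_file_contents));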
std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
                              const ModelFlags& model_flags,
                              const std::string& input_file_contents) {
  std::unique_ptr<Model> model;
  switch (toco_flags.input_format()) {
    case TENSORFLOW_GRAPHDEF: {
      TensorFlowImportFlags tf_import_flags;
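      // Unless the user explicitly set --drop_control_dependency, control
      // dependencies are dropped whenever the output format is not
      // TENSORFLOW_GRAPHDEF, since the other output formats do not represent
      // them.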
      tf_import_flags.drop_control_dependency =
          toco_flags.has_drop_control_dependency()
              ? toco_flags.drop_control_dependency()
              : (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);

      tf_import_flags.import_all_ops_as_unsupported =
          toco_flags.force_select_tf_ops();

      model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
                                       input_file_contents);
      break;
    }
    case TFLITE:
      model = toco::tflite::Import(model_flags, input_file_contents);
      ResolveModelFlags(model_flags, model.get());
      CheckInvariants(*model);
      break;
    default:
      LOG(FATAL) << "Unhandled input_format='"
                 << FileFormat_Name(toco_flags.input_format()) << "'";
  }

  LogDump(kLogLevelModelChanged, "AT IMPORT", *model);

  return model;
}

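// Applies the full transformation pipeline to |model| in place: cleanup
// after import, the general graph transformations, format-specific fusing,
// and then either quantization or dequantization, followed by allocation
// and consistency checks.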
tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
                                       Model* model) {
  const FileFormat output_format = toco_flags.output_format();
  const IODataType inference_type = toco_flags.inference_type();

  const bool quantize_output =
      SupportsQuantization(output_format) &&
      (inference_type == QUANTIZED_UINT8 || inference_type == QUANTIZED_INT16);

  if (quantize_output) {
    QCHECK_NE(toco_flags.inference_input_type(), FLOAT)
        << "Quantized inference is not allowed with float inputs.";
  }

  // Clean up after import.
  SetFinalDataTypeOnInputs(toco_flags, model);
  UseArraysExtraInfo(model, quantize_output);
  FinishBuildingRNNStates(model);

  // Remove unused ops before performing any other optimizations. This is to
  // stop optimizations from crossing the input/output boundaries. For
  // example, this will stop BatchNorm fusing if the output node is between a
  // conv layer and a BatchNorm layer.
  TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
      model, "Removing unused ops", {new toco::RemoveUnusedOp}));

  GraphTransformationsSet transformations;
  MakeGeneralGraphTransformationsSet(&transformations);
  auto* remove_trivial_reshape = new RemoveTrivialReshape;
  transformations.Add(remove_trivial_reshape);
  auto* resolve_constant_fake_quant = new ResolveConstantFakeQuant;
  if (quantize_output) {
    resolve_constant_fake_quant->set_propagate_fake_quant_num_bits(
        toco_flags.propagate_fake_quant_num_bits());
  }
  transformations.Add(resolve_constant_fake_quant);
  if (SupportsFusedActivationFunction(output_format)) {
    transformations.Add(new FuseActivationFunctions);
  } else {
    transformations.Add(new UnfuseActivationFunctions);
  }
  if (toco_flags.drop_fake_quant()) {
    transformations.Add(new DropFakeQuant);
  } else {
    // See the doc for --reorder_across_fake_quant: that flag is needed to
    // support some existing models, e.g. WordLens, that have FakeQuant
    // nodes in the wrong places.
    // TODO(benoitjacob): drop special casing when we can.
    if (quantize_output && toco_flags.reorder_across_fake_quant()) {
      transformations.Add(new DropFakeQuant);
    }
  }
  transformations.Add(new ConvertPureConvToDepthwise);
  if (SupportsLstmCell(output_format)) {
    if (!toco_flags.debug_disable_recurrent_cell_fusion()) {
      transformations.Add(new IdentifyLstmCell);
    }
    if (output_format == TFLITE && toco_flags.split_tflite_lstm_inputs()) {
      transformations.Add(new toco::SplitLstmCellInputs);
    } else {
      transformations.Add(new toco::MergeLstmCellInputs);
    }
  }
  transformations.Add(new ResolveConstantConcatenation);
  // TODO(b/116063589): TF GraphDef doesn't support dilations on its depthwise
  // conv, so we need to make sure we don't convert to dilated depthwise conv
  // when outputting to TF GraphDef.
  auto* identify_dilated_conv = new IdentifyDilatedConv;
  if (output_format == TENSORFLOW_GRAPHDEF) {
    identify_dilated_conv->set_identify_depthwise_conv(false);
  }
  transformations.Add(identify_dilated_conv);
  TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
      model, "general graph transformations", transformations));

  if (quantize_output) {
    if (toco_flags.propagate_fake_quant_num_bits()) {
      TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
          model, "fake quant propagation graph transformations",
          {new PropagateFakeQuantNumBits}));
    }
    TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
        model, "pre-quantization graph transformations",
        {
            new HardcodeMinMax,
            new DropFakeQuant,
        }));
  }

  // Try to merge bidirectional sequence lstm or rnn if present.
  GraphTransformationsSet bidirectional_transformations;
  bidirectional_transformations.Add(new RemoveUnusedOp);
  bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceLstm);
  bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceRnn);
  bidirectional_transformations.Add(
      new toco::GroupDynamicBidirectionalSequenceRnn);
  bidirectional_transformations.Add(
      new toco::GroupDynamicBidirectionalSequenceLstm);
  TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
      model, "Group bidirectional sequence lstm/rnn",
      bidirectional_transformations));

  // Fix any issues with IO edges. This must happen after any transform that
  // may modify the structure of the edges.
  FixEdgeArrays(model);
  FixOperatorOrdering(model);

  if (quantize_output) {
    // If the user specified default min/max ranges, we need to set them on
    // all arrays that neither had a min/max specified nor got one set via
    // HardcodeMinMax or PropagateFakeQuantNumBits. This may require running
    // HardcodeMinMax again to propagate the changes through the graph.
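    // For example, passing --default_ranges_min=0 --default_ranges_max=6
    // (illustrative values) gives every still-unannotated uint8 array the
    // range [0, 6] via the DefineTypeRange call below.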
    auto propagate_default_min_max = std::make_unique<PropagateDefaultMinMax>();
    bool has_default_ranges_flag = (toco_flags.has_default_ranges_min() &&
                                    toco_flags.has_default_ranges_max());
    if (has_default_ranges_flag) {
      propagate_default_min_max->DefineTypeRange(
          ArrayDataType::kUint8, toco_flags.default_ranges_min(),
          toco_flags.default_ranges_max());
    }
    if (toco_flags.has_default_int16_ranges_min() &&
        toco_flags.has_default_int16_ranges_max()) {
      propagate_default_min_max->DefineTypeRange(
          ArrayDataType::kInt16, toco_flags.default_int16_ranges_min(),
          toco_flags.default_int16_ranges_max());
    }
    if (propagate_default_min_max->has_any_ranges_defined()) {
      TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
          model, "default min-max range propagation graph transformations",
          {
              propagate_default_min_max.release(),
              new HardcodeMinMax,
          }));
    }

    CheckIsReadyForQuantization(*model);
    auto* ensure_safe_for_int8_kernels =
        new EnsureUint8WeightsSafeForFastInt8Kernels;
    ensure_safe_for_int8_kernels->set_allow_nudging_weights(
        toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
    ensure_safe_for_int8_kernels->set_has_default_ranges_flag(
        has_default_ranges_flag);
    TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
        model, "quantization graph transformations",
        {
            new RemoveTrivialQuantizedActivationFunc,
            new RemoveTrivialQuantizedMinMax,
            new Quantize,
            new RemoveFinalDequantizeOp,
            ensure_safe_for_int8_kernels,
        }));
    if (SupportsShuffledFCWeights(output_format)) {
      TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
          model, "shuffling of FC weights", {new ShuffleFCWeights}));
    }
  } else {
    GraphTransformationsSet dequantization_transformations{new Dequantize};
    // Dequantize creates FakeQuant nodes. We may want to discard
    // those immediately.
    if (toco_flags.drop_fake_quant()) {
      dequantization_transformations.Add(new DropFakeQuant);
    }

    TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
        model, "dequantization graph transformations",
        dequantization_transformations));
  }

  // It's unfortunate that we have to run this graph transformation here: the
  // user may choose to use a broadcast mul to do nearest-neighbor upsampling,
  // that is:
  //   Input [1, 20, 1, 20, 1, 64] * ones [1, 3, 1, 3, 1, 1]
  // The problem is that if the input is quantized, the quantization
  // parameters will be slightly different for the input and the output
  // (although the difference is really small). Since we're changing this
  // pattern to be pack-based, which enforces that the quantization parameters
  // are exactly the same, we have to wait for all quantization parameters to
  // be resolved and propagated before creating our own.
  // We may need to revisit this logic later.
  GraphTransformationsSet nearest_upsample_transformations;
  nearest_upsample_transformations.Add(new toco::IdentifyNearestUpsample);
  TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
      model, "Identify nearest upsample.", nearest_upsample_transformations));

  if (output_format == TENSORFLOW_GRAPHDEF) {
    EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model);
  }

  // Deduplicate large constant arrays.
  DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes());

  LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);

  if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) {
    // By now there shouldn't be any unsupported ops when exporting to
    // TensorFlow GraphDef.
    CheckUnsupportedOperations(*model);
  }

  if (SupportsPreallocatedWorkspace(output_format)) {
    AllocateTransientArrays(model, kDefaultTransientDataAlignment);
    LogDump(kLogLevelModelChanged, "AFTER ALLOCATION", *model);
  }

  CheckModelCounts(*model);
  CheckFinalDataTypesSatisfied(*model);

  // Estimate and log the number of arithmetic ops.
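  // (A multiply-accumulate, or MAC, counts as two arithmetic ops: one
  // multiply plus one add, hence the division by 2 below.)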
  int64_t ops_count = 0;
  if (EstimateArithmeticOpsCount(*model, &ops_count)) {
    LOG(INFO) << "Estimated count of arithmetic ops: " << ops_count
              << " ops, equivalently " << ops_count / 2 << " MACs";
  }
  model->ops_count = ops_count;
  int64_t params_count = 0;

  // Compute and log the number of parameters.
  for (const auto& array_pair : model->GetArrayMap()) {
    const Array& array = *array_pair.second;
    if (!array.buffer) {
      // Not a parameter array.
      continue;
    }
    params_count += RequiredBufferSizeForShape(array.shape());
  }
  LOG(INFO) << "Number of parameters: " << params_count;
  return ::tensorflow::OkStatus();
}

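// Serializes |model| into |output_file_contents| in the format given by
// toco_flags.output_format().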
tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
                          bool allow_custom_ops,
                          std::string* output_file_contents) {
  switch (toco_flags.output_format()) {
    case TENSORFLOW_GRAPHDEF:
      ExportTensorFlowGraphDef(model, output_file_contents);
      break;
    case TFLITE: {
      toco::tflite::ExportParams params;

      params.enable_select_tf_ops =
          toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops();
      params.allow_custom_ops = allow_custom_ops;
      params.allow_dynamic_tensors = toco_flags.allow_dynamic_tensors();
      params.disable_per_channel =
          toco_flags.disable_per_channel_quantization();
      if (toco_flags.post_training_quantize()) {
        if (toco_flags.quantize_to_float16()) {
          params.quantize_weights = tflite::QuantizedBufferType::FLOAT16;
        } else {
          params.quantize_weights = tflite::QuantizedBufferType::INT8;
        }
      }
      auto status = toco::tflite::Export(model, output_file_contents, params);
      if (!status.ok()) {
        LOG(ERROR) << status.error_message();
      }
      return status;
    } break;
    case GRAPHVIZ_DOT:
      DumpGraphviz(model, output_file_contents, "Computation Graph");
      break;
    default:
      LOG(FATAL) << "Unhandled output_format='"
                 << FileFormat_Name(toco_flags.output_format()) << "'";
  }
  return tensorflow::Status();
}

}  // namespace toco