1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/lite/toco/toco_cmdline_flags.h"

#include <cstring>
#include <optional>
#include <string>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/strip.h"
#include "absl/types/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/lite/toco/toco_port.h"
30 | |
31 | namespace toco { |
32 | |
33 | bool ParseTocoFlagsFromCommandLineFlags( |
34 | int* argc, char* argv[], std::string* msg, |
35 | ParsedTocoFlags* parsed_toco_flags_ptr) { |
36 | using tensorflow::Flag; |
37 | ParsedTocoFlags& parsed_flags = *parsed_toco_flags_ptr; |
38 | std::vector<tensorflow::Flag> flags = { |
39 | Flag("input_file" , parsed_flags.input_file.bind(), |
40 | parsed_flags.input_file.default_value(), |
41 | "Input file (model of any supported format). For Protobuf " |
42 | "formats, both text and binary are supported regardless of file " |
43 | "extension." ), |
44 | Flag("savedmodel_directory" , parsed_flags.savedmodel_directory.bind(), |
45 | parsed_flags.savedmodel_directory.default_value(), |
46 | "Deprecated. Full path to the directory containing the SavedModel." ), |
47 | Flag("output_file" , parsed_flags.output_file.bind(), |
48 | parsed_flags.output_file.default_value(), |
49 | "Output file. " |
50 | "For Protobuf formats, the binary format will be used." ), |
51 | Flag("input_format" , parsed_flags.input_format.bind(), |
52 | parsed_flags.input_format.default_value(), |
53 | "Input file format. One of: TENSORFLOW_GRAPHDEF, TFLITE." ), |
54 | Flag("output_format" , parsed_flags.output_format.bind(), |
55 | parsed_flags.output_format.default_value(), |
56 | "Output file format. " |
57 | "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT." ), |
58 | Flag("savedmodel_tagset" , parsed_flags.savedmodel_tagset.bind(), |
59 | parsed_flags.savedmodel_tagset.default_value(), |
60 | "Deprecated. Comma-separated set of tags identifying the " |
61 | "MetaGraphDef within the SavedModel to analyze. All tags in the tag " |
62 | "set must be specified." ), |
63 | Flag("default_ranges_min" , parsed_flags.default_ranges_min.bind(), |
64 | parsed_flags.default_ranges_min.default_value(), |
65 | "If defined, will be used as the default value for the min bound " |
66 | "of min/max ranges used for quantization of uint8 arrays." ), |
67 | Flag("default_ranges_max" , parsed_flags.default_ranges_max.bind(), |
68 | parsed_flags.default_ranges_max.default_value(), |
69 | "If defined, will be used as the default value for the max bound " |
70 | "of min/max ranges used for quantization of uint8 arrays." ), |
71 | Flag("default_int16_ranges_min" , |
72 | parsed_flags.default_int16_ranges_min.bind(), |
73 | parsed_flags.default_int16_ranges_min.default_value(), |
74 | "If defined, will be used as the default value for the min bound " |
75 | "of min/max ranges used for quantization of int16 arrays." ), |
76 | Flag("default_int16_ranges_max" , |
77 | parsed_flags.default_int16_ranges_max.bind(), |
78 | parsed_flags.default_int16_ranges_max.default_value(), |
79 | "If defined, will be used as the default value for the max bound " |
80 | "of min/max ranges used for quantization of int16 arrays." ), |
81 | Flag("inference_type" , parsed_flags.inference_type.bind(), |
82 | parsed_flags.inference_type.default_value(), |
83 | "Target data type of arrays in the output file (for input_arrays, " |
84 | "this may be overridden by inference_input_type). " |
85 | "One of FLOAT, QUANTIZED_UINT8." ), |
86 | Flag("inference_input_type" , parsed_flags.inference_input_type.bind(), |
87 | parsed_flags.inference_input_type.default_value(), |
88 | "Target data type of input arrays. " |
89 | "If not specified, inference_type is used. " |
90 | "One of FLOAT, QUANTIZED_UINT8." ), |
91 | Flag("input_type" , parsed_flags.input_type.bind(), |
92 | parsed_flags.input_type.default_value(), |
93 | "Deprecated ambiguous flag that set both --input_data_types and " |
94 | "--inference_input_type." ), |
95 | Flag("input_types" , parsed_flags.input_types.bind(), |
96 | parsed_flags.input_types.default_value(), |
97 | "Deprecated ambiguous flag that set both --input_data_types and " |
98 | "--inference_input_type. Was meant to be a " |
99 | "comma-separated list, but this was deprecated before " |
100 | "multiple-input-types was ever properly supported." ), |
101 | |
102 | Flag("drop_fake_quant" , parsed_flags.drop_fake_quant.bind(), |
103 | parsed_flags.drop_fake_quant.default_value(), |
104 | "Ignore and discard FakeQuant nodes. For instance, to " |
105 | "generate plain float code without fake-quantization from a " |
106 | "quantized graph." ), |
107 | Flag( |
108 | "reorder_across_fake_quant" , |
109 | parsed_flags.reorder_across_fake_quant.bind(), |
110 | parsed_flags.reorder_across_fake_quant.default_value(), |
111 | "Normally, FakeQuant nodes must be strict boundaries for graph " |
112 | "transformations, in order to ensure that quantized inference has " |
113 | "the exact same arithmetic behavior as quantized training --- which " |
114 | "is the whole point of quantized training and of FakeQuant nodes in " |
115 | "the first place. " |
116 | "However, that entails subtle requirements on where exactly " |
117 | "FakeQuant nodes must be placed in the graph. Some quantized graphs " |
118 | "have FakeQuant nodes at unexpected locations, that prevent graph " |
119 | "transformations that are necessary in order to generate inference " |
120 | "code for these graphs. Such graphs should be fixed, but as a " |
121 | "temporary work-around, setting this reorder_across_fake_quant flag " |
122 | "allows TOCO to perform necessary graph transformaitons on them, " |
123 | "at the cost of no longer faithfully matching inference and training " |
124 | "arithmetic." ), |
125 | Flag("allow_custom_ops" , parsed_flags.allow_custom_ops.bind(), |
126 | parsed_flags.allow_custom_ops.default_value(), |
127 | "If true, allow TOCO to create TF Lite Custom operators for all the " |
128 | "unsupported TensorFlow ops." ), |
129 | Flag("custom_opdefs" , parsed_flags.custom_opdefs.bind(), |
130 | parsed_flags.custom_opdefs.default_value(), |
131 | "List of strings representing custom ops OpDefs that are included " |
132 | "in the GraphDef." ), |
133 | Flag("allow_dynamic_tensors" , parsed_flags.allow_dynamic_tensors.bind(), |
134 | parsed_flags.allow_dynamic_tensors.default_value(), |
135 | "Boolean flag indicating whether the converter should allow models " |
136 | "with dynamic Tensor shape. When set to False, the converter will " |
137 | "generate runtime memory offsets for activation Tensors (with 128 " |
138 | "bits alignment) and error out on models with undetermined Tensor " |
139 | "shape. (Default: True)" ), |
140 | Flag( |
141 | "drop_control_dependency" , |
142 | parsed_flags.drop_control_dependency.bind(), |
143 | parsed_flags.drop_control_dependency.default_value(), |
144 | "If true, ignore control dependency requirements in input TensorFlow " |
145 | "GraphDef. Otherwise an error will be raised upon control dependency " |
146 | "inputs." ), |
147 | Flag("debug_disable_recurrent_cell_fusion" , |
148 | parsed_flags.debug_disable_recurrent_cell_fusion.bind(), |
149 | parsed_flags.debug_disable_recurrent_cell_fusion.default_value(), |
150 | "If true, disable fusion of known identifiable cell subgraphs into " |
151 | "cells. This includes, for example, specific forms of LSTM cell." ), |
152 | Flag("propagate_fake_quant_num_bits" , |
153 | parsed_flags.propagate_fake_quant_num_bits.bind(), |
154 | parsed_flags.propagate_fake_quant_num_bits.default_value(), |
155 | "If true, use FakeQuant* operator num_bits attributes to adjust " |
156 | "array data_types." ), |
157 | Flag("allow_nudging_weights_to_use_fast_gemm_kernel" , |
158 | parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel.bind(), |
159 | parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel |
160 | .default_value(), |
161 | "Some fast uint8 GEMM kernels require uint8 weights to avoid the " |
162 | "value 0. This flag allows nudging them to 1 to allow proceeding, " |
163 | "with moderate inaccuracy." ), |
164 | Flag("dedupe_array_min_size_bytes" , |
165 | parsed_flags.dedupe_array_min_size_bytes.bind(), |
166 | parsed_flags.dedupe_array_min_size_bytes.default_value(), |
167 | "Minimum size of constant arrays to deduplicate; arrays smaller " |
168 | "will not be deduplicated." ), |
169 | Flag("split_tflite_lstm_inputs" , |
170 | parsed_flags.split_tflite_lstm_inputs.bind(), |
171 | parsed_flags.split_tflite_lstm_inputs.default_value(), |
172 | "Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. " |
173 | "Ignored if the output format is not TFLite." ), |
174 | Flag("quantize_to_float16" , parsed_flags.quantize_to_float16.bind(), |
175 | parsed_flags.quantize_to_float16.default_value(), |
176 | "Used in conjunction with post_training_quantize. Specifies that " |
177 | "the weights should be quantized to fp16 instead of the default " |
178 | "(int8)" ), |
179 | Flag("quantize_weights" , parsed_flags.quantize_weights.bind(), |
180 | parsed_flags.quantize_weights.default_value(), |
181 | "Deprecated. Please use --post_training_quantize instead." ), |
182 | Flag("post_training_quantize" , parsed_flags.post_training_quantize.bind(), |
183 | parsed_flags.post_training_quantize.default_value(), |
184 | "Boolean indicating whether to quantize the weights of the " |
185 | "converted float model. Model size will be reduced and there will " |
186 | "be latency improvements (at the cost of accuracy)." ), |
187 | // TODO(b/118822804): Unify the argument definition with `tflite_convert`. |
188 | // WARNING: Experimental interface, subject to change |
189 | Flag("enable_select_tf_ops" , parsed_flags.enable_select_tf_ops.bind(), |
190 | parsed_flags.enable_select_tf_ops.default_value(), "" ), |
191 | // WARNING: Experimental interface, subject to change |
192 | Flag("force_select_tf_ops" , parsed_flags.force_select_tf_ops.bind(), |
193 | parsed_flags.force_select_tf_ops.default_value(), "" ), |
194 | // WARNING: Experimental interface, subject to change |
195 | Flag("unfold_batchmatmul" , parsed_flags.unfold_batchmatmul.bind(), |
196 | parsed_flags.unfold_batchmatmul.default_value(), "" ), |
197 | // WARNING: Experimental interface, subject to change |
198 | Flag("accumulation_type" , parsed_flags.accumulation_type.bind(), |
199 | parsed_flags.accumulation_type.default_value(), |
200 | "Accumulation type to use with quantize_to_float16" ), |
201 | // WARNING: Experimental interface, subject to change |
202 | Flag("allow_bfloat16" , parsed_flags.allow_bfloat16.bind(), |
203 | parsed_flags.allow_bfloat16.default_value(), "" )}; |
204 | |
205 | bool asked_for_help = |
206 | *argc == 2 && (!strcmp(argv[1], "--help" ) || !strcmp(argv[1], "-help" )); |
207 | if (asked_for_help) { |
208 | *msg += tensorflow::Flags::Usage(argv[0], flags); |
209 | return false; |
210 | } else { |
211 | return tensorflow::Flags::Parse(argc, argv, flags); |
212 | } |
213 | } |
214 | |
215 | namespace { |
216 | |
// Defines the requirements for a given flag. kUseDefault means the default
// should be used in cases where the value isn't specified by the user.
enum class FlagRequirement {
  // No constraint; the flag's value is only consumed if the user set it.
  kNone,
  // EnforceFlagRequirement QCHECK-fails unless the user specified the flag.
  kMustBeSpecified,
  // EnforceFlagRequirement QCHECK-fails if the user specified the flag.
  kMustNotBeSpecified,
  // GetFlagValue falls back to the flag's default when it wasn't specified.
  kUseDefault,
};
225 | |
226 | // Enforces the FlagRequirements are met for a given flag. |
227 | template <typename T> |
228 | void EnforceFlagRequirement(const T& flag, const std::string& flag_name, |
229 | FlagRequirement requirement) { |
230 | if (requirement == FlagRequirement::kMustBeSpecified) { |
231 | QCHECK(flag.specified()) << "Missing required flag " << flag_name; |
232 | } |
233 | if (requirement == FlagRequirement::kMustNotBeSpecified) { |
234 | QCHECK(!flag.specified()) |
235 | << "Given other flags, this flag should not have been specified: " |
236 | << flag_name; |
237 | } |
238 | } |
239 | |
240 | // Gets the value from the flag if specified. Returns default if the |
241 | // FlagRequirement is kUseDefault. |
242 | template <typename T> |
243 | std::optional<T> GetFlagValue(const Arg<T>& flag, FlagRequirement requirement) { |
244 | if (flag.specified()) return flag.value(); |
245 | if (requirement == FlagRequirement::kUseDefault) return flag.default_value(); |
246 | return std::optional<T>(); |
247 | } |
248 | |
249 | } // namespace |
250 | |
251 | void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, |
252 | TocoFlags* toco_flags) { |
253 | namespace port = toco::port; |
254 | port::CheckInitGoogleIsDone("InitGoogle is not done yet" ); |
255 | |
256 | #define READ_TOCO_FLAG(name, requirement) \ |
257 | do { \ |
258 | EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \ |
259 | auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \ |
260 | if (flag_value.has_value()) { \ |
261 | toco_flags->set_##name(flag_value.value()); \ |
262 | } \ |
263 | } while (false) |
264 | |
265 | #define PARSE_TOCO_FLAG(Type, name, requirement) \ |
266 | do { \ |
267 | EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement); \ |
268 | auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \ |
269 | if (flag_value.has_value()) { \ |
270 | Type x; \ |
271 | QCHECK(Type##_Parse(flag_value.value(), &x)) \ |
272 | << "Unrecognized " << #Type << " value " \ |
273 | << parsed_toco_flags.name.value(); \ |
274 | toco_flags->set_##name(x); \ |
275 | } \ |
276 | } while (false) |
277 | |
278 | PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kUseDefault); |
279 | PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kUseDefault); |
280 | PARSE_TOCO_FLAG(IODataType, inference_type, FlagRequirement::kNone); |
281 | PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone); |
282 | READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone); |
283 | READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone); |
284 | READ_TOCO_FLAG(default_int16_ranges_min, FlagRequirement::kNone); |
285 | READ_TOCO_FLAG(default_int16_ranges_max, FlagRequirement::kNone); |
286 | READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone); |
287 | READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone); |
288 | READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone); |
289 | READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone); |
290 | READ_TOCO_FLAG(debug_disable_recurrent_cell_fusion, FlagRequirement::kNone); |
291 | READ_TOCO_FLAG(propagate_fake_quant_num_bits, FlagRequirement::kNone); |
292 | READ_TOCO_FLAG(allow_nudging_weights_to_use_fast_gemm_kernel, |
293 | FlagRequirement::kNone); |
294 | READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone); |
295 | READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone); |
296 | READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone); |
297 | READ_TOCO_FLAG(quantize_to_float16, FlagRequirement::kNone); |
298 | READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone); |
299 | READ_TOCO_FLAG(enable_select_tf_ops, FlagRequirement::kNone); |
300 | READ_TOCO_FLAG(force_select_tf_ops, FlagRequirement::kNone); |
301 | READ_TOCO_FLAG(unfold_batchmatmul, FlagRequirement::kNone); |
302 | PARSE_TOCO_FLAG(IODataType, accumulation_type, FlagRequirement::kNone); |
303 | READ_TOCO_FLAG(allow_bfloat16, FlagRequirement::kNone); |
304 | |
305 | if (parsed_toco_flags.force_select_tf_ops.value() && |
306 | !parsed_toco_flags.enable_select_tf_ops.value()) { |
307 | // TODO(ycling): Consider to enforce `enable_select_tf_ops` when |
308 | // `force_select_tf_ops` is true. |
309 | LOG(WARNING) << "--force_select_tf_ops should always be used with " |
310 | "--enable_select_tf_ops." ; |
311 | } |
312 | |
313 | // Deprecated flag handling. |
314 | if (parsed_toco_flags.input_type.specified()) { |
315 | LOG(WARNING) |
316 | << "--input_type is deprecated. It was an ambiguous flag that set both " |
317 | "--input_data_types and --inference_input_type. If you are trying " |
318 | "to complement the input file with information about the type of " |
319 | "input arrays, use --input_data_type. If you are trying to control " |
320 | "the quantization/dequantization of real-numbers input arrays in " |
321 | "the output file, use --inference_input_type." ; |
322 | toco::IODataType input_type; |
323 | QCHECK(toco::IODataType_Parse(parsed_toco_flags.input_type.value(), |
324 | &input_type)); |
325 | toco_flags->set_inference_input_type(input_type); |
326 | } |
327 | if (parsed_toco_flags.input_types.specified()) { |
328 | LOG(WARNING) |
329 | << "--input_types is deprecated. It was an ambiguous flag that set " |
330 | "both --input_data_types and --inference_input_type. If you are " |
331 | "trying to complement the input file with information about the " |
332 | "type of input arrays, use --input_data_type. If you are trying to " |
333 | "control the quantization/dequantization of real-numbers input " |
334 | "arrays in the output file, use --inference_input_type." ; |
335 | std::vector<std::string> input_types = |
336 | absl::StrSplit(parsed_toco_flags.input_types.value(), ','); |
337 | QCHECK(!input_types.empty()); |
338 | for (size_t i = 1; i < input_types.size(); i++) { |
339 | QCHECK_EQ(input_types[i], input_types[0]); |
340 | } |
341 | toco::IODataType input_type; |
342 | QCHECK(toco::IODataType_Parse(input_types[0], &input_type)); |
343 | toco_flags->set_inference_input_type(input_type); |
344 | } |
345 | if (parsed_toco_flags.quantize_weights.value()) { |
346 | LOG(WARNING) |
347 | << "--quantize_weights is deprecated. Falling back to " |
348 | "--post_training_quantize. Please switch --post_training_quantize." ; |
349 | toco_flags->set_post_training_quantize( |
350 | parsed_toco_flags.quantize_weights.value()); |
351 | } |
352 | if (parsed_toco_flags.quantize_weights.value()) { |
353 | if (toco_flags->inference_type() == IODataType::QUANTIZED_UINT8) { |
354 | LOG(WARNING) |
355 | << "--post_training_quantize quantizes a graph of inference_type " |
356 | "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT." ; |
357 | toco_flags->set_inference_type(IODataType::FLOAT); |
358 | } |
359 | } |
360 | |
361 | #undef READ_TOCO_FLAG |
362 | #undef PARSE_TOCO_FLAG |
363 | } |
364 | } // namespace toco |
365 | |