1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/toco/model_cmdline_flags.h" |
16 | |
17 | #include <string> |
18 | #include <vector> |
19 | |
20 | #include "absl/strings/numbers.h" |
21 | #include "absl/strings/str_join.h" |
22 | #include "absl/strings/str_split.h" |
23 | #include "absl/strings/string_view.h" |
24 | #include "absl/strings/strip.h" |
25 | #include "tensorflow/core/platform/logging.h" |
26 | #include "tensorflow/core/util/command_line_flags.h" |
27 | #include "tensorflow/lite/toco/args.h" |
28 | #include "tensorflow/lite/toco/toco_graphviz_dump_options.h" |
29 | #include "tensorflow/lite/toco/toco_port.h" |
30 | |
31 | // "batch" flag only exists internally |
32 | #ifdef PLATFORM_GOOGLE |
33 | #include "base/commandlineflags.h" |
34 | #endif |
35 | |
36 | namespace toco { |
37 | |
38 | bool ParseModelFlagsFromCommandLineFlags( |
39 | int* argc, char* argv[], std::string* msg, |
40 | ParsedModelFlags* parsed_model_flags_ptr) { |
41 | ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr; |
42 | using tensorflow::Flag; |
43 | std::vector<tensorflow::Flag> flags = { |
44 | Flag("input_array" , parsed_flags.input_array.bind(), |
45 | parsed_flags.input_array.default_value(), |
46 | "Deprecated: use --input_arrays instead. Name of the input array. " |
47 | "If not specified, will try to read " |
48 | "that information from the input file." ), |
49 | Flag("input_arrays" , parsed_flags.input_arrays.bind(), |
50 | parsed_flags.input_arrays.default_value(), |
51 | "Names of the input arrays, comma-separated. If not specified, " |
52 | "will try to read that information from the input file." ), |
53 | Flag("output_array" , parsed_flags.output_array.bind(), |
54 | parsed_flags.output_array.default_value(), |
55 | "Deprecated: use --output_arrays instead. Name of the output array, " |
56 | "when specifying a unique output array. " |
57 | "If not specified, will try to read that information from the " |
58 | "input file." ), |
59 | Flag("output_arrays" , parsed_flags.output_arrays.bind(), |
60 | parsed_flags.output_arrays.default_value(), |
61 | "Names of the output arrays, comma-separated. " |
62 | "If not specified, will try to read " |
63 | "that information from the input file." ), |
64 | Flag("input_shape" , parsed_flags.input_shape.bind(), |
65 | parsed_flags.input_shape.default_value(), |
66 | "Deprecated: use --input_shapes instead. Input array shape. For " |
67 | "many models the shape takes the form " |
68 | "batch size, input array height, input array width, input array " |
69 | "depth." ), |
70 | Flag("input_shapes" , parsed_flags.input_shapes.bind(), |
71 | parsed_flags.input_shapes.default_value(), |
72 | "Shapes corresponding to --input_arrays, colon-separated. For " |
73 | "many models each shape takes the form batch size, input array " |
74 | "height, input array width, input array depth." ), |
75 | Flag("batch_size" , parsed_flags.batch_size.bind(), |
76 | parsed_flags.batch_size.default_value(), |
77 | "Deprecated. Batch size for the model. Replaces the first dimension " |
78 | "of an input size array if undefined. Use only with SavedModels " |
79 | "when --input_shapes flag is not specified. Always use " |
80 | "--input_shapes flag with frozen graphs." ), |
81 | Flag("input_data_type" , parsed_flags.input_data_type.bind(), |
82 | parsed_flags.input_data_type.default_value(), |
83 | "Deprecated: use --input_data_types instead. Input array type, if " |
84 | "not already provided in the graph. " |
85 | "Typically needs to be specified when passing arbitrary arrays " |
86 | "to --input_arrays." ), |
87 | Flag("input_data_types" , parsed_flags.input_data_types.bind(), |
88 | parsed_flags.input_data_types.default_value(), |
89 | "Input arrays types, comma-separated, if not already provided in " |
90 | "the graph. " |
91 | "Typically needs to be specified when passing arbitrary arrays " |
92 | "to --input_arrays." ), |
93 | Flag("mean_value" , parsed_flags.mean_value.bind(), |
94 | parsed_flags.mean_value.default_value(), |
95 | "Deprecated: use --mean_values instead. mean_value parameter for " |
96 | "image models, used to compute input " |
97 | "activations from input pixel data." ), |
98 | Flag("mean_values" , parsed_flags.mean_values.bind(), |
99 | parsed_flags.mean_values.default_value(), |
100 | "mean_values parameter for image models, comma-separated list of " |
101 | "doubles, used to compute input activations from input pixel " |
102 | "data. Each entry in the list should match an entry in " |
103 | "--input_arrays." ), |
104 | Flag("std_value" , parsed_flags.std_value.bind(), |
105 | parsed_flags.std_value.default_value(), |
106 | "Deprecated: use --std_values instead. std_value parameter for " |
107 | "image models, used to compute input " |
108 | "activations from input pixel data." ), |
109 | Flag("std_values" , parsed_flags.std_values.bind(), |
110 | parsed_flags.std_values.default_value(), |
111 | "std_value parameter for image models, comma-separated list of " |
112 | "doubles, used to compute input activations from input pixel " |
113 | "data. Each entry in the list should match an entry in " |
114 | "--input_arrays." ), |
115 | Flag("variable_batch" , parsed_flags.variable_batch.bind(), |
116 | parsed_flags.variable_batch.default_value(), |
117 | "If true, the model accepts an arbitrary batch size. Mutually " |
118 | "exclusive " |
119 | "with the 'batch' field: at most one of these two fields can be " |
120 | "set." ), |
121 | Flag("rnn_states" , parsed_flags.rnn_states.bind(), |
122 | parsed_flags.rnn_states.default_value(), "" ), |
123 | Flag("model_checks" , parsed_flags.model_checks.bind(), |
124 | parsed_flags.model_checks.default_value(), |
125 | "A list of model checks to be applied to verify the form of the " |
126 | "model. Applied after the graph transformations after import." ), |
127 | Flag("dump_graphviz" , parsed_flags.dump_graphviz.bind(), |
128 | parsed_flags.dump_graphviz.default_value(), |
129 | "Dump graphviz during LogDump call. If string is non-empty then " |
130 | "it defines path to dump, otherwise will skip dumping." ), |
131 | Flag("dump_graphviz_video" , parsed_flags.dump_graphviz_video.bind(), |
132 | parsed_flags.dump_graphviz_video.default_value(), |
133 | "If true, will dump graphviz at each " |
134 | "graph transformation, which may be used to generate a video." ), |
135 | Flag("conversion_summary_dir" , parsed_flags.conversion_summary_dir.bind(), |
136 | parsed_flags.conversion_summary_dir.default_value(), |
137 | "Local file directory to store the conversion logs." ), |
138 | Flag("allow_nonexistent_arrays" , |
139 | parsed_flags.allow_nonexistent_arrays.bind(), |
140 | parsed_flags.allow_nonexistent_arrays.default_value(), |
141 | "If true, will allow passing inexistent arrays in --input_arrays " |
142 | "and --output_arrays. This makes little sense, is only useful to " |
143 | "more easily get graph visualizations." ), |
144 | Flag("allow_nonascii_arrays" , parsed_flags.allow_nonascii_arrays.bind(), |
145 | parsed_flags.allow_nonascii_arrays.default_value(), |
146 | "If true, will allow passing non-ascii-printable characters in " |
147 | "--input_arrays and --output_arrays. By default (if false), only " |
148 | "ascii printable characters are allowed, i.e. character codes " |
149 | "ranging from 32 to 127. This is disallowed by default so as to " |
150 | "catch common copy-and-paste issues where invisible unicode " |
151 | "characters are unwittingly added to these strings." ), |
152 | Flag( |
153 | "arrays_extra_info_file" , parsed_flags.arrays_extra_info_file.bind(), |
154 | parsed_flags.arrays_extra_info_file.default_value(), |
155 | "Path to an optional file containing a serialized ArraysExtraInfo " |
156 | "proto allowing to pass extra information about arrays not specified " |
157 | "in the input model file, such as extra MinMax information." ), |
158 | Flag("model_flags_file" , parsed_flags.model_flags_file.bind(), |
159 | parsed_flags.model_flags_file.default_value(), |
160 | "Path to an optional file containing a serialized ModelFlags proto. " |
161 | "Options specified on the command line will override the values in " |
162 | "the proto." ), |
163 | Flag("change_concat_input_ranges" , |
164 | parsed_flags.change_concat_input_ranges.bind(), |
165 | parsed_flags.change_concat_input_ranges.default_value(), |
166 | "Boolean to change the behavior of min/max ranges for inputs and" |
167 | " output of the concat operators." ), |
168 | }; |
169 | bool asked_for_help = |
170 | *argc == 2 && (!strcmp(argv[1], "--help" ) || !strcmp(argv[1], "-help" )); |
171 | if (asked_for_help) { |
172 | *msg += tensorflow::Flags::Usage(argv[0], flags); |
173 | return false; |
174 | } else { |
175 | if (!tensorflow::Flags::Parse(argc, argv, flags)) return false; |
176 | } |
177 | auto& dump_options = *GraphVizDumpOptions::singleton(); |
178 | dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value(); |
179 | dump_options.dump_graphviz = parsed_flags.dump_graphviz.value(); |
180 | |
181 | return true; |
182 | } |
183 | |
184 | void ReadModelFlagsFromCommandLineFlags( |
185 | const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) { |
186 | toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet" ); |
187 | |
188 | // Load proto containing the initial model flags. |
189 | // Additional flags specified on the command line will overwrite the values. |
190 | if (parsed_model_flags.model_flags_file.specified()) { |
191 | std::string model_flags_file_contents; |
192 | QCHECK(port::file::GetContents(parsed_model_flags.model_flags_file.value(), |
193 | &model_flags_file_contents, |
194 | port::file::Defaults()) |
195 | .ok()) |
196 | << "Specified --model_flags_file=" |
197 | << parsed_model_flags.model_flags_file.value() |
198 | << " was not found or could not be read" ; |
199 | QCHECK(ParseFromStringEitherTextOrBinary(model_flags_file_contents, |
200 | model_flags)) |
201 | << "Specified --model_flags_file=" |
202 | << parsed_model_flags.model_flags_file.value() |
203 | << " could not be parsed" ; |
204 | } |
205 | |
206 | #ifdef PLATFORM_GOOGLE |
207 | CHECK(!((base::WasPresentOnCommandLine("batch" ) && |
208 | parsed_model_flags.variable_batch.specified()))) |
209 | << "The --batch and --variable_batch flags are mutually exclusive." ; |
210 | #endif |
211 | CHECK(!(parsed_model_flags.output_array.specified() && |
212 | parsed_model_flags.output_arrays.specified())) |
213 | << "The --output_array and --vs flags are mutually exclusive." ; |
214 | |
215 | if (parsed_model_flags.output_array.specified()) { |
216 | model_flags->add_output_arrays(parsed_model_flags.output_array.value()); |
217 | } |
218 | |
219 | if (parsed_model_flags.output_arrays.specified()) { |
220 | std::vector<std::string> output_arrays = |
221 | absl::StrSplit(parsed_model_flags.output_arrays.value(), ','); |
222 | for (const std::string& output_array : output_arrays) { |
223 | model_flags->add_output_arrays(output_array); |
224 | } |
225 | } |
226 | |
227 | const bool uses_single_input_flags = |
228 | parsed_model_flags.input_array.specified() || |
229 | parsed_model_flags.mean_value.specified() || |
230 | parsed_model_flags.std_value.specified() || |
231 | parsed_model_flags.input_shape.specified(); |
232 | |
233 | const bool uses_multi_input_flags = |
234 | parsed_model_flags.input_arrays.specified() || |
235 | parsed_model_flags.mean_values.specified() || |
236 | parsed_model_flags.std_values.specified() || |
237 | parsed_model_flags.input_shapes.specified(); |
238 | |
239 | QCHECK(!(uses_single_input_flags && uses_multi_input_flags)) |
240 | << "Use either the singular-form input flags (--input_array, " |
241 | "--input_shape, --mean_value, --std_value) or the plural form input " |
242 | "flags (--input_arrays, --input_shapes, --mean_values, --std_values), " |
243 | "but not both forms within the same command line." ; |
244 | |
245 | if (parsed_model_flags.input_array.specified()) { |
246 | QCHECK(uses_single_input_flags); |
247 | model_flags->add_input_arrays()->set_name( |
248 | parsed_model_flags.input_array.value()); |
249 | } |
250 | if (parsed_model_flags.input_arrays.specified()) { |
251 | QCHECK(uses_multi_input_flags); |
252 | for (const auto& input_array : |
253 | absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) { |
254 | model_flags->add_input_arrays()->set_name(std::string(input_array)); |
255 | } |
256 | } |
257 | if (parsed_model_flags.mean_value.specified()) { |
258 | QCHECK(uses_single_input_flags); |
259 | model_flags->mutable_input_arrays(0)->set_mean_value( |
260 | parsed_model_flags.mean_value.value()); |
261 | } |
262 | if (parsed_model_flags.mean_values.specified()) { |
263 | QCHECK(uses_multi_input_flags); |
264 | std::vector<std::string> mean_values = |
265 | absl::StrSplit(parsed_model_flags.mean_values.value(), ','); |
266 | QCHECK(static_cast<int>(mean_values.size()) == |
267 | model_flags->input_arrays_size()); |
268 | for (size_t i = 0; i < mean_values.size(); ++i) { |
269 | char* last = nullptr; |
270 | model_flags->mutable_input_arrays(i)->set_mean_value( |
271 | strtod(mean_values[i].data(), &last)); |
272 | CHECK(last != mean_values[i].data()); |
273 | } |
274 | } |
275 | if (parsed_model_flags.std_value.specified()) { |
276 | QCHECK(uses_single_input_flags); |
277 | model_flags->mutable_input_arrays(0)->set_std_value( |
278 | parsed_model_flags.std_value.value()); |
279 | } |
280 | if (parsed_model_flags.std_values.specified()) { |
281 | QCHECK(uses_multi_input_flags); |
282 | std::vector<std::string> std_values = |
283 | absl::StrSplit(parsed_model_flags.std_values.value(), ','); |
284 | QCHECK(static_cast<int>(std_values.size()) == |
285 | model_flags->input_arrays_size()); |
286 | for (size_t i = 0; i < std_values.size(); ++i) { |
287 | char* last = nullptr; |
288 | model_flags->mutable_input_arrays(i)->set_std_value( |
289 | strtod(std_values[i].data(), &last)); |
290 | CHECK(last != std_values[i].data()); |
291 | } |
292 | } |
293 | if (parsed_model_flags.input_data_type.specified()) { |
294 | QCHECK(uses_single_input_flags); |
295 | IODataType type; |
296 | QCHECK(IODataType_Parse(parsed_model_flags.input_data_type.value(), &type)); |
297 | model_flags->mutable_input_arrays(0)->set_data_type(type); |
298 | } |
299 | if (parsed_model_flags.input_data_types.specified()) { |
300 | QCHECK(uses_multi_input_flags); |
301 | std::vector<std::string> input_data_types = |
302 | absl::StrSplit(parsed_model_flags.input_data_types.value(), ','); |
303 | QCHECK(static_cast<int>(input_data_types.size()) == |
304 | model_flags->input_arrays_size()); |
305 | for (size_t i = 0; i < input_data_types.size(); ++i) { |
306 | IODataType type; |
307 | QCHECK(IODataType_Parse(input_data_types[i], &type)); |
308 | model_flags->mutable_input_arrays(i)->set_data_type(type); |
309 | } |
310 | } |
311 | if (parsed_model_flags.input_shape.specified()) { |
312 | QCHECK(uses_single_input_flags); |
313 | if (model_flags->input_arrays().empty()) { |
314 | model_flags->add_input_arrays(); |
315 | } |
316 | auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape(); |
317 | shape->clear_dims(); |
318 | const IntList& list = parsed_model_flags.input_shape.value(); |
319 | for (auto& dim : list.elements) { |
320 | shape->add_dims(dim); |
321 | } |
322 | } |
323 | if (parsed_model_flags.input_shapes.specified()) { |
324 | QCHECK(uses_multi_input_flags); |
325 | std::vector<std::string> input_shapes = |
326 | absl::StrSplit(parsed_model_flags.input_shapes.value(), ':'); |
327 | QCHECK(static_cast<int>(input_shapes.size()) == |
328 | model_flags->input_arrays_size()); |
329 | for (size_t i = 0; i < input_shapes.size(); ++i) { |
330 | auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape(); |
331 | shape->clear_dims(); |
332 | // Treat an empty input shape as a scalar. |
333 | if (input_shapes[i].empty()) { |
334 | continue; |
335 | } |
336 | for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) { |
337 | int size; |
338 | CHECK(absl::SimpleAtoi(dim_str, &size)) |
339 | << "Failed to parse input_shape: " << input_shapes[i]; |
340 | shape->add_dims(size); |
341 | } |
342 | } |
343 | } |
344 | |
345 | #define READ_MODEL_FLAG(name) \ |
346 | do { \ |
347 | if (parsed_model_flags.name.specified()) { \ |
348 | model_flags->set_##name(parsed_model_flags.name.value()); \ |
349 | } \ |
350 | } while (false) |
351 | |
352 | READ_MODEL_FLAG(variable_batch); |
353 | |
354 | #undef READ_MODEL_FLAG |
355 | |
356 | for (const auto& element : parsed_model_flags.rnn_states.value().elements) { |
357 | auto* rnn_state_proto = model_flags->add_rnn_states(); |
358 | for (const auto& kv_pair : element) { |
359 | const std::string& key = kv_pair.first; |
360 | const std::string& value = kv_pair.second; |
361 | if (key == "state_array" ) { |
362 | rnn_state_proto->set_state_array(value); |
363 | } else if (key == "back_edge_source_array" ) { |
364 | rnn_state_proto->set_back_edge_source_array(value); |
365 | } else if (key == "size" ) { |
366 | int32_t size = 0; |
367 | CHECK(absl::SimpleAtoi(value, &size)); |
368 | CHECK_GT(size, 0); |
369 | rnn_state_proto->set_size(size); |
370 | } else if (key == "num_dims" ) { |
371 | int32_t size = 0; |
372 | CHECK(absl::SimpleAtoi(value, &size)); |
373 | CHECK_GT(size, 0); |
374 | rnn_state_proto->set_num_dims(size); |
375 | } else { |
376 | LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states" ; |
377 | } |
378 | } |
379 | CHECK(rnn_state_proto->has_state_array() && |
380 | rnn_state_proto->has_back_edge_source_array() && |
381 | rnn_state_proto->has_size()) |
382 | << "--rnn_states must include state_array, back_edge_source_array and " |
383 | "size." ; |
384 | } |
385 | |
386 | for (const auto& element : parsed_model_flags.model_checks.value().elements) { |
387 | auto* model_check_proto = model_flags->add_model_checks(); |
388 | for (const auto& kv_pair : element) { |
389 | const std::string& key = kv_pair.first; |
390 | const std::string& value = kv_pair.second; |
391 | if (key == "count_type" ) { |
392 | model_check_proto->set_count_type(value); |
393 | } else if (key == "count_min" ) { |
394 | int32_t count = 0; |
395 | CHECK(absl::SimpleAtoi(value, &count)); |
396 | CHECK_GE(count, -1); |
397 | model_check_proto->set_count_min(count); |
398 | } else if (key == "count_max" ) { |
399 | int32_t count = 0; |
400 | CHECK(absl::SimpleAtoi(value, &count)); |
401 | CHECK_GE(count, -1); |
402 | model_check_proto->set_count_max(count); |
403 | } else { |
404 | LOG(FATAL) << "Unknown key '" << key << "' in --model_checks" ; |
405 | } |
406 | } |
407 | } |
408 | |
409 | if (!model_flags->has_allow_nonascii_arrays()) { |
410 | model_flags->set_allow_nonascii_arrays( |
411 | parsed_model_flags.allow_nonascii_arrays.value()); |
412 | } |
413 | if (!model_flags->has_allow_nonexistent_arrays()) { |
414 | model_flags->set_allow_nonexistent_arrays( |
415 | parsed_model_flags.allow_nonexistent_arrays.value()); |
416 | } |
417 | if (!model_flags->has_change_concat_input_ranges()) { |
418 | model_flags->set_change_concat_input_ranges( |
419 | parsed_model_flags.change_concat_input_ranges.value()); |
420 | } |
421 | |
422 | if (parsed_model_flags.arrays_extra_info_file.specified()) { |
423 | std::string ; |
424 | CHECK(port::file::GetContents( |
425 | parsed_model_flags.arrays_extra_info_file.value(), |
426 | &arrays_extra_info_file_contents, port::file::Defaults()) |
427 | .ok()); |
428 | ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents, |
429 | model_flags->mutable_arrays_extra_info()); |
430 | } |
431 | } |
432 | |
433 | ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) { |
434 | static auto* flags = [must_already_exist]() { |
435 | if (must_already_exist) { |
436 | fprintf(stderr, __FILE__ |
437 | ":" |
438 | "GlobalParsedModelFlags() used without initialization\n" ); |
439 | fflush(stderr); |
440 | abort(); |
441 | } |
442 | return new toco::ParsedModelFlags; |
443 | }(); |
444 | return flags; |
445 | } |
446 | |
447 | ParsedModelFlags* GlobalParsedModelFlags() { |
448 | return UncheckedGlobalParsedModelFlags(true); |
449 | } |
450 | |
451 | void ParseModelFlagsOrDie(int* argc, char* argv[]) { |
452 | // TODO(aselle): in the future allow Google version to use |
453 | // flags, and only use this mechanism for open source |
454 | auto* flags = UncheckedGlobalParsedModelFlags(false); |
455 | std::string msg; |
456 | bool model_success = |
457 | toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags); |
458 | if (!model_success || !msg.empty()) { |
459 | // Log in non-standard way since this happens pre InitGoogle. |
460 | fprintf(stderr, "%s" , msg.c_str()); |
461 | fflush(stderr); |
462 | abort(); |
463 | } |
464 | } |
465 | |
466 | } // namespace toco |
467 | |