1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/toco/model_cmdline_flags.h"
16
17#include <string>
18#include <vector>
19
20#include "absl/strings/numbers.h"
21#include "absl/strings/str_join.h"
22#include "absl/strings/str_split.h"
23#include "absl/strings/string_view.h"
24#include "absl/strings/strip.h"
25#include "tensorflow/core/platform/logging.h"
26#include "tensorflow/core/util/command_line_flags.h"
27#include "tensorflow/lite/toco/args.h"
28#include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
29#include "tensorflow/lite/toco/toco_port.h"
30
31// "batch" flag only exists internally
32#ifdef PLATFORM_GOOGLE
33#include "base/commandlineflags.h"
34#endif
35
36namespace toco {
37
38bool ParseModelFlagsFromCommandLineFlags(
39 int* argc, char* argv[], std::string* msg,
40 ParsedModelFlags* parsed_model_flags_ptr) {
41 ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr;
42 using tensorflow::Flag;
43 std::vector<tensorflow::Flag> flags = {
44 Flag("input_array", parsed_flags.input_array.bind(),
45 parsed_flags.input_array.default_value(),
46 "Deprecated: use --input_arrays instead. Name of the input array. "
47 "If not specified, will try to read "
48 "that information from the input file."),
49 Flag("input_arrays", parsed_flags.input_arrays.bind(),
50 parsed_flags.input_arrays.default_value(),
51 "Names of the input arrays, comma-separated. If not specified, "
52 "will try to read that information from the input file."),
53 Flag("output_array", parsed_flags.output_array.bind(),
54 parsed_flags.output_array.default_value(),
55 "Deprecated: use --output_arrays instead. Name of the output array, "
56 "when specifying a unique output array. "
57 "If not specified, will try to read that information from the "
58 "input file."),
59 Flag("output_arrays", parsed_flags.output_arrays.bind(),
60 parsed_flags.output_arrays.default_value(),
61 "Names of the output arrays, comma-separated. "
62 "If not specified, will try to read "
63 "that information from the input file."),
64 Flag("input_shape", parsed_flags.input_shape.bind(),
65 parsed_flags.input_shape.default_value(),
66 "Deprecated: use --input_shapes instead. Input array shape. For "
67 "many models the shape takes the form "
68 "batch size, input array height, input array width, input array "
69 "depth."),
70 Flag("input_shapes", parsed_flags.input_shapes.bind(),
71 parsed_flags.input_shapes.default_value(),
72 "Shapes corresponding to --input_arrays, colon-separated. For "
73 "many models each shape takes the form batch size, input array "
74 "height, input array width, input array depth."),
75 Flag("batch_size", parsed_flags.batch_size.bind(),
76 parsed_flags.batch_size.default_value(),
77 "Deprecated. Batch size for the model. Replaces the first dimension "
78 "of an input size array if undefined. Use only with SavedModels "
79 "when --input_shapes flag is not specified. Always use "
80 "--input_shapes flag with frozen graphs."),
81 Flag("input_data_type", parsed_flags.input_data_type.bind(),
82 parsed_flags.input_data_type.default_value(),
83 "Deprecated: use --input_data_types instead. Input array type, if "
84 "not already provided in the graph. "
85 "Typically needs to be specified when passing arbitrary arrays "
86 "to --input_arrays."),
87 Flag("input_data_types", parsed_flags.input_data_types.bind(),
88 parsed_flags.input_data_types.default_value(),
89 "Input arrays types, comma-separated, if not already provided in "
90 "the graph. "
91 "Typically needs to be specified when passing arbitrary arrays "
92 "to --input_arrays."),
93 Flag("mean_value", parsed_flags.mean_value.bind(),
94 parsed_flags.mean_value.default_value(),
95 "Deprecated: use --mean_values instead. mean_value parameter for "
96 "image models, used to compute input "
97 "activations from input pixel data."),
98 Flag("mean_values", parsed_flags.mean_values.bind(),
99 parsed_flags.mean_values.default_value(),
100 "mean_values parameter for image models, comma-separated list of "
101 "doubles, used to compute input activations from input pixel "
102 "data. Each entry in the list should match an entry in "
103 "--input_arrays."),
104 Flag("std_value", parsed_flags.std_value.bind(),
105 parsed_flags.std_value.default_value(),
106 "Deprecated: use --std_values instead. std_value parameter for "
107 "image models, used to compute input "
108 "activations from input pixel data."),
109 Flag("std_values", parsed_flags.std_values.bind(),
110 parsed_flags.std_values.default_value(),
111 "std_value parameter for image models, comma-separated list of "
112 "doubles, used to compute input activations from input pixel "
113 "data. Each entry in the list should match an entry in "
114 "--input_arrays."),
115 Flag("variable_batch", parsed_flags.variable_batch.bind(),
116 parsed_flags.variable_batch.default_value(),
117 "If true, the model accepts an arbitrary batch size. Mutually "
118 "exclusive "
119 "with the 'batch' field: at most one of these two fields can be "
120 "set."),
121 Flag("rnn_states", parsed_flags.rnn_states.bind(),
122 parsed_flags.rnn_states.default_value(), ""),
123 Flag("model_checks", parsed_flags.model_checks.bind(),
124 parsed_flags.model_checks.default_value(),
125 "A list of model checks to be applied to verify the form of the "
126 "model. Applied after the graph transformations after import."),
127 Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(),
128 parsed_flags.dump_graphviz.default_value(),
129 "Dump graphviz during LogDump call. If string is non-empty then "
130 "it defines path to dump, otherwise will skip dumping."),
131 Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(),
132 parsed_flags.dump_graphviz_video.default_value(),
133 "If true, will dump graphviz at each "
134 "graph transformation, which may be used to generate a video."),
135 Flag("conversion_summary_dir", parsed_flags.conversion_summary_dir.bind(),
136 parsed_flags.conversion_summary_dir.default_value(),
137 "Local file directory to store the conversion logs."),
138 Flag("allow_nonexistent_arrays",
139 parsed_flags.allow_nonexistent_arrays.bind(),
140 parsed_flags.allow_nonexistent_arrays.default_value(),
141 "If true, will allow passing inexistent arrays in --input_arrays "
142 "and --output_arrays. This makes little sense, is only useful to "
143 "more easily get graph visualizations."),
144 Flag("allow_nonascii_arrays", parsed_flags.allow_nonascii_arrays.bind(),
145 parsed_flags.allow_nonascii_arrays.default_value(),
146 "If true, will allow passing non-ascii-printable characters in "
147 "--input_arrays and --output_arrays. By default (if false), only "
148 "ascii printable characters are allowed, i.e. character codes "
149 "ranging from 32 to 127. This is disallowed by default so as to "
150 "catch common copy-and-paste issues where invisible unicode "
151 "characters are unwittingly added to these strings."),
152 Flag(
153 "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(),
154 parsed_flags.arrays_extra_info_file.default_value(),
155 "Path to an optional file containing a serialized ArraysExtraInfo "
156 "proto allowing to pass extra information about arrays not specified "
157 "in the input model file, such as extra MinMax information."),
158 Flag("model_flags_file", parsed_flags.model_flags_file.bind(),
159 parsed_flags.model_flags_file.default_value(),
160 "Path to an optional file containing a serialized ModelFlags proto. "
161 "Options specified on the command line will override the values in "
162 "the proto."),
163 Flag("change_concat_input_ranges",
164 parsed_flags.change_concat_input_ranges.bind(),
165 parsed_flags.change_concat_input_ranges.default_value(),
166 "Boolean to change the behavior of min/max ranges for inputs and"
167 " output of the concat operators."),
168 };
169 bool asked_for_help =
170 *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
171 if (asked_for_help) {
172 *msg += tensorflow::Flags::Usage(argv[0], flags);
173 return false;
174 } else {
175 if (!tensorflow::Flags::Parse(argc, argv, flags)) return false;
176 }
177 auto& dump_options = *GraphVizDumpOptions::singleton();
178 dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value();
179 dump_options.dump_graphviz = parsed_flags.dump_graphviz.value();
180
181 return true;
182}
183
184void ReadModelFlagsFromCommandLineFlags(
185 const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
186 toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");
187
188 // Load proto containing the initial model flags.
189 // Additional flags specified on the command line will overwrite the values.
190 if (parsed_model_flags.model_flags_file.specified()) {
191 std::string model_flags_file_contents;
192 QCHECK(port::file::GetContents(parsed_model_flags.model_flags_file.value(),
193 &model_flags_file_contents,
194 port::file::Defaults())
195 .ok())
196 << "Specified --model_flags_file="
197 << parsed_model_flags.model_flags_file.value()
198 << " was not found or could not be read";
199 QCHECK(ParseFromStringEitherTextOrBinary(model_flags_file_contents,
200 model_flags))
201 << "Specified --model_flags_file="
202 << parsed_model_flags.model_flags_file.value()
203 << " could not be parsed";
204 }
205
206#ifdef PLATFORM_GOOGLE
207 CHECK(!((base::WasPresentOnCommandLine("batch") &&
208 parsed_model_flags.variable_batch.specified())))
209 << "The --batch and --variable_batch flags are mutually exclusive.";
210#endif
211 CHECK(!(parsed_model_flags.output_array.specified() &&
212 parsed_model_flags.output_arrays.specified()))
213 << "The --output_array and --vs flags are mutually exclusive.";
214
215 if (parsed_model_flags.output_array.specified()) {
216 model_flags->add_output_arrays(parsed_model_flags.output_array.value());
217 }
218
219 if (parsed_model_flags.output_arrays.specified()) {
220 std::vector<std::string> output_arrays =
221 absl::StrSplit(parsed_model_flags.output_arrays.value(), ',');
222 for (const std::string& output_array : output_arrays) {
223 model_flags->add_output_arrays(output_array);
224 }
225 }
226
227 const bool uses_single_input_flags =
228 parsed_model_flags.input_array.specified() ||
229 parsed_model_flags.mean_value.specified() ||
230 parsed_model_flags.std_value.specified() ||
231 parsed_model_flags.input_shape.specified();
232
233 const bool uses_multi_input_flags =
234 parsed_model_flags.input_arrays.specified() ||
235 parsed_model_flags.mean_values.specified() ||
236 parsed_model_flags.std_values.specified() ||
237 parsed_model_flags.input_shapes.specified();
238
239 QCHECK(!(uses_single_input_flags && uses_multi_input_flags))
240 << "Use either the singular-form input flags (--input_array, "
241 "--input_shape, --mean_value, --std_value) or the plural form input "
242 "flags (--input_arrays, --input_shapes, --mean_values, --std_values), "
243 "but not both forms within the same command line.";
244
245 if (parsed_model_flags.input_array.specified()) {
246 QCHECK(uses_single_input_flags);
247 model_flags->add_input_arrays()->set_name(
248 parsed_model_flags.input_array.value());
249 }
250 if (parsed_model_flags.input_arrays.specified()) {
251 QCHECK(uses_multi_input_flags);
252 for (const auto& input_array :
253 absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) {
254 model_flags->add_input_arrays()->set_name(std::string(input_array));
255 }
256 }
257 if (parsed_model_flags.mean_value.specified()) {
258 QCHECK(uses_single_input_flags);
259 model_flags->mutable_input_arrays(0)->set_mean_value(
260 parsed_model_flags.mean_value.value());
261 }
262 if (parsed_model_flags.mean_values.specified()) {
263 QCHECK(uses_multi_input_flags);
264 std::vector<std::string> mean_values =
265 absl::StrSplit(parsed_model_flags.mean_values.value(), ',');
266 QCHECK(static_cast<int>(mean_values.size()) ==
267 model_flags->input_arrays_size());
268 for (size_t i = 0; i < mean_values.size(); ++i) {
269 char* last = nullptr;
270 model_flags->mutable_input_arrays(i)->set_mean_value(
271 strtod(mean_values[i].data(), &last));
272 CHECK(last != mean_values[i].data());
273 }
274 }
275 if (parsed_model_flags.std_value.specified()) {
276 QCHECK(uses_single_input_flags);
277 model_flags->mutable_input_arrays(0)->set_std_value(
278 parsed_model_flags.std_value.value());
279 }
280 if (parsed_model_flags.std_values.specified()) {
281 QCHECK(uses_multi_input_flags);
282 std::vector<std::string> std_values =
283 absl::StrSplit(parsed_model_flags.std_values.value(), ',');
284 QCHECK(static_cast<int>(std_values.size()) ==
285 model_flags->input_arrays_size());
286 for (size_t i = 0; i < std_values.size(); ++i) {
287 char* last = nullptr;
288 model_flags->mutable_input_arrays(i)->set_std_value(
289 strtod(std_values[i].data(), &last));
290 CHECK(last != std_values[i].data());
291 }
292 }
293 if (parsed_model_flags.input_data_type.specified()) {
294 QCHECK(uses_single_input_flags);
295 IODataType type;
296 QCHECK(IODataType_Parse(parsed_model_flags.input_data_type.value(), &type));
297 model_flags->mutable_input_arrays(0)->set_data_type(type);
298 }
299 if (parsed_model_flags.input_data_types.specified()) {
300 QCHECK(uses_multi_input_flags);
301 std::vector<std::string> input_data_types =
302 absl::StrSplit(parsed_model_flags.input_data_types.value(), ',');
303 QCHECK(static_cast<int>(input_data_types.size()) ==
304 model_flags->input_arrays_size());
305 for (size_t i = 0; i < input_data_types.size(); ++i) {
306 IODataType type;
307 QCHECK(IODataType_Parse(input_data_types[i], &type));
308 model_flags->mutable_input_arrays(i)->set_data_type(type);
309 }
310 }
311 if (parsed_model_flags.input_shape.specified()) {
312 QCHECK(uses_single_input_flags);
313 if (model_flags->input_arrays().empty()) {
314 model_flags->add_input_arrays();
315 }
316 auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape();
317 shape->clear_dims();
318 const IntList& list = parsed_model_flags.input_shape.value();
319 for (auto& dim : list.elements) {
320 shape->add_dims(dim);
321 }
322 }
323 if (parsed_model_flags.input_shapes.specified()) {
324 QCHECK(uses_multi_input_flags);
325 std::vector<std::string> input_shapes =
326 absl::StrSplit(parsed_model_flags.input_shapes.value(), ':');
327 QCHECK(static_cast<int>(input_shapes.size()) ==
328 model_flags->input_arrays_size());
329 for (size_t i = 0; i < input_shapes.size(); ++i) {
330 auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape();
331 shape->clear_dims();
332 // Treat an empty input shape as a scalar.
333 if (input_shapes[i].empty()) {
334 continue;
335 }
336 for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) {
337 int size;
338 CHECK(absl::SimpleAtoi(dim_str, &size))
339 << "Failed to parse input_shape: " << input_shapes[i];
340 shape->add_dims(size);
341 }
342 }
343 }
344
345#define READ_MODEL_FLAG(name) \
346 do { \
347 if (parsed_model_flags.name.specified()) { \
348 model_flags->set_##name(parsed_model_flags.name.value()); \
349 } \
350 } while (false)
351
352 READ_MODEL_FLAG(variable_batch);
353
354#undef READ_MODEL_FLAG
355
356 for (const auto& element : parsed_model_flags.rnn_states.value().elements) {
357 auto* rnn_state_proto = model_flags->add_rnn_states();
358 for (const auto& kv_pair : element) {
359 const std::string& key = kv_pair.first;
360 const std::string& value = kv_pair.second;
361 if (key == "state_array") {
362 rnn_state_proto->set_state_array(value);
363 } else if (key == "back_edge_source_array") {
364 rnn_state_proto->set_back_edge_source_array(value);
365 } else if (key == "size") {
366 int32_t size = 0;
367 CHECK(absl::SimpleAtoi(value, &size));
368 CHECK_GT(size, 0);
369 rnn_state_proto->set_size(size);
370 } else if (key == "num_dims") {
371 int32_t size = 0;
372 CHECK(absl::SimpleAtoi(value, &size));
373 CHECK_GT(size, 0);
374 rnn_state_proto->set_num_dims(size);
375 } else {
376 LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
377 }
378 }
379 CHECK(rnn_state_proto->has_state_array() &&
380 rnn_state_proto->has_back_edge_source_array() &&
381 rnn_state_proto->has_size())
382 << "--rnn_states must include state_array, back_edge_source_array and "
383 "size.";
384 }
385
386 for (const auto& element : parsed_model_flags.model_checks.value().elements) {
387 auto* model_check_proto = model_flags->add_model_checks();
388 for (const auto& kv_pair : element) {
389 const std::string& key = kv_pair.first;
390 const std::string& value = kv_pair.second;
391 if (key == "count_type") {
392 model_check_proto->set_count_type(value);
393 } else if (key == "count_min") {
394 int32_t count = 0;
395 CHECK(absl::SimpleAtoi(value, &count));
396 CHECK_GE(count, -1);
397 model_check_proto->set_count_min(count);
398 } else if (key == "count_max") {
399 int32_t count = 0;
400 CHECK(absl::SimpleAtoi(value, &count));
401 CHECK_GE(count, -1);
402 model_check_proto->set_count_max(count);
403 } else {
404 LOG(FATAL) << "Unknown key '" << key << "' in --model_checks";
405 }
406 }
407 }
408
409 if (!model_flags->has_allow_nonascii_arrays()) {
410 model_flags->set_allow_nonascii_arrays(
411 parsed_model_flags.allow_nonascii_arrays.value());
412 }
413 if (!model_flags->has_allow_nonexistent_arrays()) {
414 model_flags->set_allow_nonexistent_arrays(
415 parsed_model_flags.allow_nonexistent_arrays.value());
416 }
417 if (!model_flags->has_change_concat_input_ranges()) {
418 model_flags->set_change_concat_input_ranges(
419 parsed_model_flags.change_concat_input_ranges.value());
420 }
421
422 if (parsed_model_flags.arrays_extra_info_file.specified()) {
423 std::string arrays_extra_info_file_contents;
424 CHECK(port::file::GetContents(
425 parsed_model_flags.arrays_extra_info_file.value(),
426 &arrays_extra_info_file_contents, port::file::Defaults())
427 .ok());
428 ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
429 model_flags->mutable_arrays_extra_info());
430 }
431}
432
433ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
434 static auto* flags = [must_already_exist]() {
435 if (must_already_exist) {
436 fprintf(stderr, __FILE__
437 ":"
438 "GlobalParsedModelFlags() used without initialization\n");
439 fflush(stderr);
440 abort();
441 }
442 return new toco::ParsedModelFlags;
443 }();
444 return flags;
445}
446
// Accessor for the global ParsedModelFlags that insists the instance has
// already been created: it forwards to UncheckedGlobalParsedModelFlags with
// `must_already_exist` set to true.
ParsedModelFlags* GlobalParsedModelFlags() {
  constexpr bool kMustAlreadyExist = true;
  return UncheckedGlobalParsedModelFlags(kMustAlreadyExist);
}
450
451void ParseModelFlagsOrDie(int* argc, char* argv[]) {
452 // TODO(aselle): in the future allow Google version to use
453 // flags, and only use this mechanism for open source
454 auto* flags = UncheckedGlobalParsedModelFlags(false);
455 std::string msg;
456 bool model_success =
457 toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags);
458 if (!model_success || !msg.empty()) {
459 // Log in non-standard way since this happens pre InitGoogle.
460 fprintf(stderr, "%s", msg.c_str());
461 fflush(stderr);
462 abort();
463 }
464}
465
466} // namespace toco
467