1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
// This abstracts command line arguments in toco.
// Arg<T> is a parseable type that can register a default value, parse itself,
// and keep track of whether it was specified.
18#ifndef TENSORFLOW_LITE_TOCO_ARGS_H_
19#define TENSORFLOW_LITE_TOCO_ARGS_H_
20
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/lite/toco/toco_types.h"
30
31namespace toco {
32
33// Since std::vector<int32> is in the std namespace, and we are not allowed
34// to add ParseFlag/UnparseFlag to std, we introduce a simple wrapper type
35// to use as the flag type:
36struct IntList {
37 std::vector<int32> elements;
38};
39struct StringMapList {
40 std::vector<std::unordered_map<std::string, std::string>> elements;
41};
42
// command_line_flags.h doesn't track whether or not a flag was specified. Arg
// holds the flag's value (the default if the flag was not specified) and also
// records whether the flag was specified.
// TODO(aselle): consider putting doc string and ability to construct the
// tensorflow argument into this, so declaration of parameters can be less
// distributed.
// Every template specialization of Arg is required to implement
// default_value(), specified(), value(), Parse(), bind().
template <class T>
class Arg final {
 public:
  // Creates an unspecified argument whose value and default are both
  // `default_val`.
  explicit Arg(T default_val = T())
      : default_value_(default_val), value_(std::move(default_val)) {}

  // Default value to hand to the flag registration code. Unlike value(), this
  // keeps returning the constructor-supplied default even after Parse() has
  // stored a command-line value.
  T default_value() const { return default_value_; }
  // Returns true if the argument was specified on the command line.
  bool specified() const { return specified_; }
  // Current value: the parsed value once Parse() has run, else the default.
  const T& value() const { return value_; }

  // Parsing callback for the tensorflow::Flags code. Stores `value_in`, marks
  // the argument as specified, and always reports success.
  bool Parse(T value_in) {
    value_ = std::move(value_in);
    specified_ = true;
    return true;
  }

  // Returns a callback that forwards to Parse() on this object, so
  // tensorflow::Flags can call it. The callback captures a raw `this`
  // pointer, so this Arg must outlive every copy of the returned function.
  std::function<bool(T)> bind() {
    return [this](T value_in) { return Parse(std::move(value_in)); };
  }

 private:
  // Value supplied at construction; never modified afterwards. Declared
  // before value_ so it is copied before value_ moves from the ctor argument.
  T default_value_;
  // Current value; starts out equal to default_value_.
  T value_;
  // Becomes true once Parse() has been called.
  bool specified_ = false;
};
82
83template <>
84class Arg<toco::IntList> final {
85 public:
86 // Provide default_value() to arg list
87 std::string default_value() const { return ""; }
88 // Return true if the command line argument was specified on the command line.
89 bool specified() const { return specified_; }
90 // Bind the parse member function so tensorflow::Flags can call it.
91 bool Parse(std::string text);
92
93 std::function<bool(std::string)> bind() {
94 return std::bind(&Arg::Parse, this, std::placeholders::_1);
95 }
96
97 const toco::IntList& value() const { return parsed_value_; }
98
99 private:
100 toco::IntList parsed_value_;
101 bool specified_ = false;
102};
103
104template <>
105class Arg<toco::StringMapList> final {
106 public:
107 // Provide default_value() to StringMapList
108 std::string default_value() const { return ""; }
109 // Return true if the command line argument was specified on the command line.
110 bool specified() const { return specified_; }
111 // Bind the parse member function so tensorflow::Flags can call it.
112
113 bool Parse(std::string text);
114
115 std::function<bool(std::string)> bind() {
116 return std::bind(&Arg::Parse, this, std::placeholders::_1);
117 }
118
119 const toco::StringMapList& value() const { return parsed_value_; }
120
121 private:
122 toco::StringMapList parsed_value_;
123 bool specified_ = false;
124};
125
126// Flags that describe a model. See model_cmdline_flags.cc for details.
127struct ParsedModelFlags {
128 Arg<std::string> input_array;
129 Arg<std::string> input_arrays;
130 Arg<std::string> output_array;
131 Arg<std::string> output_arrays;
132 Arg<std::string> input_shapes;
133 Arg<int> batch_size = Arg<int>(1);
134 Arg<float> mean_value = Arg<float>(0.f);
135 Arg<std::string> mean_values;
136 Arg<float> std_value = Arg<float>(1.f);
137 Arg<std::string> std_values;
138 Arg<std::string> input_data_type;
139 Arg<std::string> input_data_types;
140 Arg<bool> variable_batch = Arg<bool>(false);
141 Arg<toco::IntList> input_shape;
142 Arg<toco::StringMapList> rnn_states;
143 Arg<toco::StringMapList> model_checks;
144 Arg<bool> change_concat_input_ranges = Arg<bool>(true);
145 // Debugging output options.
146 // TODO(benoitjacob): these shouldn't be ModelFlags.
147 Arg<std::string> graphviz_first_array;
148 Arg<std::string> graphviz_last_array;
149 Arg<std::string> dump_graphviz;
150 Arg<bool> dump_graphviz_video = Arg<bool>(false);
151 Arg<std::string> conversion_summary_dir;
152 Arg<bool> allow_nonexistent_arrays = Arg<bool>(false);
153 Arg<bool> allow_nonascii_arrays = Arg<bool>(false);
154 Arg<std::string> arrays_extra_info_file;
155 Arg<std::string> model_flags_file;
156};
157
158// Flags that describe the operation you would like to do (what conversion
159// you want). See toco_cmdline_flags.cc for details.
160struct ParsedTocoFlags {
161 Arg<std::string> input_file;
162 Arg<std::string> savedmodel_directory;
163 Arg<std::string> output_file;
164 Arg<std::string> input_format = Arg<std::string>("TENSORFLOW_GRAPHDEF");
165 Arg<std::string> output_format = Arg<std::string>("TFLITE");
166 Arg<std::string> savedmodel_tagset;
167 // TODO(aselle): command_line_flags doesn't support doubles
168 Arg<float> default_ranges_min = Arg<float>(0.);
169 Arg<float> default_ranges_max = Arg<float>(0.);
170 Arg<float> default_int16_ranges_min = Arg<float>(0.);
171 Arg<float> default_int16_ranges_max = Arg<float>(0.);
172 Arg<std::string> inference_type;
173 Arg<std::string> inference_input_type;
174 Arg<bool> drop_fake_quant = Arg<bool>(false);
175 Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
176 Arg<bool> allow_custom_ops = Arg<bool>(false);
177 Arg<bool> allow_dynamic_tensors = Arg<bool>(true);
178 Arg<std::string> custom_opdefs;
179 Arg<bool> post_training_quantize = Arg<bool>(false);
180 Arg<bool> quantize_to_float16 = Arg<bool>(false);
181 // Deprecated flags
182 Arg<bool> quantize_weights = Arg<bool>(false);
183 Arg<std::string> input_type;
184 Arg<std::string> input_types;
185 Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false);
186 Arg<bool> drop_control_dependency = Arg<bool>(false);
187 Arg<bool> propagate_fake_quant_num_bits = Arg<bool>(false);
188 Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel = Arg<bool>(false);
189 Arg<int64_t> dedupe_array_min_size_bytes = Arg<int64_t>(64);
190 Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true);
191 // WARNING: Experimental interface, subject to change
192 Arg<bool> enable_select_tf_ops = Arg<bool>(false);
193 // WARNING: Experimental interface, subject to change
194 Arg<bool> force_select_tf_ops = Arg<bool>(false);
195 // WARNING: Experimental interface, subject to change
196 Arg<bool> unfold_batchmatmul = Arg<bool>(true);
197 // WARNING: Experimental interface, subject to change
198 Arg<std::string> accumulation_type;
199 // WARNING: Experimental interface, subject to change
200 Arg<bool> allow_bfloat16;
201};
202
203} // namespace toco
204#endif // TENSORFLOW_LITE_TOCO_ARGS_H_
205