1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
// This file abstracts command line arguments in toco.
// Arg<T> is a parseable type that can register a default value, parse itself,
// and keep track of whether it was specified.
18 | #ifndef TENSORFLOW_LITE_TOCO_ARGS_H_ |
19 | #define TENSORFLOW_LITE_TOCO_ARGS_H_ |
20 | |
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/lite/toco/toco_types.h"
30 | |
31 | namespace toco { |
32 | |
33 | // Since std::vector<int32> is in the std namespace, and we are not allowed |
34 | // to add ParseFlag/UnparseFlag to std, we introduce a simple wrapper type |
35 | // to use as the flag type: |
36 | struct IntList { |
37 | std::vector<int32> elements; |
38 | }; |
39 | struct StringMapList { |
40 | std::vector<std::unordered_map<std::string, std::string>> elements; |
41 | }; |
42 | |
// command_line_flags.h does not track whether or not a flag was specified. Arg
// holds the value (which stays at the default if the flag is not specified)
// and also records whether the flag was specified.
// TODO(aselle): consider putting doc string and ability to construct the
// tensorflow argument into this, so declaration of parameters can be less
// distributed.
// Every template specialization of Arg is required to implement
// default_value(), specified(), value(), Parse(), bind().
//
// The class is final and owns no raw resources, so it declares no destructor
// (Rule of Zero). A user-declared destructor would both add a needless vtable
// and suppress the implicit move operations.
template <class T>
class Arg final {
 public:
  // Constructs the argument holding `default_` as its initial value.
  explicit Arg(T default_ = T()) : value_(std::move(default_)) {}

  // Provide default_value() to the arg list. Returns a copy of the current
  // value: before Parse() runs this is the constructor's default, after a
  // Parse() call it is the parsed value.
  T default_value() const { return value_; }
  // Return true if the command line argument was specified on the command
  // line, i.e. Parse() has been called.
  bool specified() const { return specified_; }
  // Const reference to the current (default or parsed) value.
  const T& value() const { return value_; }

  // Parsing callback for the tensorflow::Flags code: stores `value_in`, marks
  // the argument as specified and always reports success.
  bool Parse(T value_in) {
    value_ = std::move(value_in);
    specified_ = true;
    return true;
  }

  // Bind the parse member function so tensorflow::Flags can call it. The
  // returned callback captures `this`, so it must not outlive this Arg.
  std::function<bool(T)> bind() {
    return [this](T value_in) { return Parse(std::move(value_in)); };
  }

 private:
  // Becomes true after parsing if the value was specified.
  bool specified_ = false;
  // Value of the argument (initialized to the default in the constructor).
  T value_;
};
82 | |
83 | template <> |
84 | class Arg<toco::IntList> final { |
85 | public: |
86 | // Provide default_value() to arg list |
87 | std::string default_value() const { return "" ; } |
88 | // Return true if the command line argument was specified on the command line. |
89 | bool specified() const { return specified_; } |
90 | // Bind the parse member function so tensorflow::Flags can call it. |
91 | bool Parse(std::string text); |
92 | |
93 | std::function<bool(std::string)> bind() { |
94 | return std::bind(&Arg::Parse, this, std::placeholders::_1); |
95 | } |
96 | |
97 | const toco::IntList& value() const { return parsed_value_; } |
98 | |
99 | private: |
100 | toco::IntList parsed_value_; |
101 | bool specified_ = false; |
102 | }; |
103 | |
104 | template <> |
105 | class Arg<toco::StringMapList> final { |
106 | public: |
107 | // Provide default_value() to StringMapList |
108 | std::string default_value() const { return "" ; } |
109 | // Return true if the command line argument was specified on the command line. |
110 | bool specified() const { return specified_; } |
111 | // Bind the parse member function so tensorflow::Flags can call it. |
112 | |
113 | bool Parse(std::string text); |
114 | |
115 | std::function<bool(std::string)> bind() { |
116 | return std::bind(&Arg::Parse, this, std::placeholders::_1); |
117 | } |
118 | |
119 | const toco::StringMapList& value() const { return parsed_value_; } |
120 | |
121 | private: |
122 | toco::StringMapList parsed_value_; |
123 | bool specified_ = false; |
124 | }; |
125 | |
126 | // Flags that describe a model. See model_cmdline_flags.cc for details. |
127 | struct ParsedModelFlags { |
128 | Arg<std::string> input_array; |
129 | Arg<std::string> input_arrays; |
130 | Arg<std::string> output_array; |
131 | Arg<std::string> output_arrays; |
132 | Arg<std::string> input_shapes; |
133 | Arg<int> batch_size = Arg<int>(1); |
134 | Arg<float> mean_value = Arg<float>(0.f); |
135 | Arg<std::string> mean_values; |
136 | Arg<float> std_value = Arg<float>(1.f); |
137 | Arg<std::string> std_values; |
138 | Arg<std::string> input_data_type; |
139 | Arg<std::string> input_data_types; |
140 | Arg<bool> variable_batch = Arg<bool>(false); |
141 | Arg<toco::IntList> input_shape; |
142 | Arg<toco::StringMapList> rnn_states; |
143 | Arg<toco::StringMapList> model_checks; |
144 | Arg<bool> change_concat_input_ranges = Arg<bool>(true); |
145 | // Debugging output options. |
146 | // TODO(benoitjacob): these shouldn't be ModelFlags. |
147 | Arg<std::string> graphviz_first_array; |
148 | Arg<std::string> graphviz_last_array; |
149 | Arg<std::string> dump_graphviz; |
150 | Arg<bool> dump_graphviz_video = Arg<bool>(false); |
151 | Arg<std::string> conversion_summary_dir; |
152 | Arg<bool> allow_nonexistent_arrays = Arg<bool>(false); |
153 | Arg<bool> allow_nonascii_arrays = Arg<bool>(false); |
154 | Arg<std::string> ; |
155 | Arg<std::string> model_flags_file; |
156 | }; |
157 | |
158 | // Flags that describe the operation you would like to do (what conversion |
159 | // you want). See toco_cmdline_flags.cc for details. |
160 | struct ParsedTocoFlags { |
161 | Arg<std::string> input_file; |
162 | Arg<std::string> savedmodel_directory; |
163 | Arg<std::string> output_file; |
164 | Arg<std::string> input_format = Arg<std::string>("TENSORFLOW_GRAPHDEF" ); |
165 | Arg<std::string> output_format = Arg<std::string>("TFLITE" ); |
166 | Arg<std::string> savedmodel_tagset; |
167 | // TODO(aselle): command_line_flags doesn't support doubles |
168 | Arg<float> default_ranges_min = Arg<float>(0.); |
169 | Arg<float> default_ranges_max = Arg<float>(0.); |
170 | Arg<float> default_int16_ranges_min = Arg<float>(0.); |
171 | Arg<float> default_int16_ranges_max = Arg<float>(0.); |
172 | Arg<std::string> inference_type; |
173 | Arg<std::string> inference_input_type; |
174 | Arg<bool> drop_fake_quant = Arg<bool>(false); |
175 | Arg<bool> reorder_across_fake_quant = Arg<bool>(false); |
176 | Arg<bool> allow_custom_ops = Arg<bool>(false); |
177 | Arg<bool> allow_dynamic_tensors = Arg<bool>(true); |
178 | Arg<std::string> custom_opdefs; |
179 | Arg<bool> post_training_quantize = Arg<bool>(false); |
180 | Arg<bool> quantize_to_float16 = Arg<bool>(false); |
181 | // Deprecated flags |
182 | Arg<bool> quantize_weights = Arg<bool>(false); |
183 | Arg<std::string> input_type; |
184 | Arg<std::string> input_types; |
185 | Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false); |
186 | Arg<bool> drop_control_dependency = Arg<bool>(false); |
187 | Arg<bool> propagate_fake_quant_num_bits = Arg<bool>(false); |
188 | Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel = Arg<bool>(false); |
189 | Arg<int64_t> dedupe_array_min_size_bytes = Arg<int64_t>(64); |
190 | Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true); |
191 | // WARNING: Experimental interface, subject to change |
192 | Arg<bool> enable_select_tf_ops = Arg<bool>(false); |
193 | // WARNING: Experimental interface, subject to change |
194 | Arg<bool> force_select_tf_ops = Arg<bool>(false); |
195 | // WARNING: Experimental interface, subject to change |
196 | Arg<bool> unfold_batchmatmul = Arg<bool>(true); |
197 | // WARNING: Experimental interface, subject to change |
198 | Arg<std::string> accumulation_type; |
199 | // WARNING: Experimental interface, subject to change |
200 | Arg<bool> allow_bfloat16; |
201 | }; |
202 | |
203 | } // namespace toco |
204 | #endif // TENSORFLOW_LITE_TOCO_ARGS_H_ |
205 | |