1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/tsl/util/command_line_flags.h"

#include <cerrno>
#include <cinttypes>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <limits>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/tsl/platform/logging.h"
#include "tensorflow/tsl/platform/str_util.h"
#include "tensorflow/tsl/platform/stringpiece.h"
#include "tensorflow/tsl/platform/stringprintf.h"
27 | |
28 | namespace tsl { |
29 | namespace { |
30 | |
31 | bool ParseStringFlag(StringPiece arg, StringPiece flag, |
32 | const std::function<bool(string)>& hook, |
33 | bool* value_parsing_ok) { |
34 | *value_parsing_ok = true; |
35 | if (absl::ConsumePrefix(&arg, "--" ) && absl::ConsumePrefix(&arg, flag) && |
36 | absl::ConsumePrefix(&arg, "=" )) { |
37 | *value_parsing_ok = hook(string(arg)); |
38 | return true; |
39 | } |
40 | |
41 | return false; |
42 | } |
43 | |
44 | bool ParseInt32Flag(StringPiece arg, StringPiece flag, |
45 | const std::function<bool(int32_t)>& hook, |
46 | bool* value_parsing_ok) { |
47 | *value_parsing_ok = true; |
48 | if (absl::ConsumePrefix(&arg, "--" ) && absl::ConsumePrefix(&arg, flag) && |
49 | absl::ConsumePrefix(&arg, "=" )) { |
50 | char ; |
51 | int32_t parsed_int32; |
52 | if (sscanf(arg.data(), "%d%c" , &parsed_int32, &extra) != 1) { |
53 | LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag |
54 | << "." ; |
55 | *value_parsing_ok = false; |
56 | } else { |
57 | *value_parsing_ok = hook(parsed_int32); |
58 | } |
59 | return true; |
60 | } |
61 | |
62 | return false; |
63 | } |
64 | |
65 | bool ParseInt64Flag(StringPiece arg, StringPiece flag, |
66 | const std::function<bool(int64_t)>& hook, |
67 | bool* value_parsing_ok) { |
68 | *value_parsing_ok = true; |
69 | if (absl::ConsumePrefix(&arg, "--" ) && absl::ConsumePrefix(&arg, flag) && |
70 | absl::ConsumePrefix(&arg, "=" )) { |
71 | char ; |
72 | int64_t parsed_int64; |
73 | if (sscanf(arg.data(), "%" SCNd64 "%c" , &parsed_int64, &extra) != 1) { |
74 | LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag |
75 | << "." ; |
76 | *value_parsing_ok = false; |
77 | } else { |
78 | *value_parsing_ok = hook(parsed_int64); |
79 | } |
80 | return true; |
81 | } |
82 | |
83 | return false; |
84 | } |
85 | |
86 | bool ParseBoolFlag(StringPiece arg, StringPiece flag, |
87 | const std::function<bool(bool)>& hook, |
88 | bool* value_parsing_ok) { |
89 | *value_parsing_ok = true; |
90 | if (absl::ConsumePrefix(&arg, "--" ) && absl::ConsumePrefix(&arg, flag)) { |
91 | if (arg.empty()) { |
92 | *value_parsing_ok = hook(true); |
93 | return true; |
94 | } |
95 | |
96 | if (arg == "=true" ) { |
97 | *value_parsing_ok = hook(true); |
98 | return true; |
99 | } else if (arg == "=false" ) { |
100 | *value_parsing_ok = hook(false); |
101 | return true; |
102 | } else { |
103 | LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag |
104 | << "." ; |
105 | *value_parsing_ok = false; |
106 | return true; |
107 | } |
108 | } |
109 | |
110 | return false; |
111 | } |
112 | |
113 | bool ParseFloatFlag(StringPiece arg, StringPiece flag, |
114 | const std::function<bool(float)>& hook, |
115 | bool* value_parsing_ok) { |
116 | *value_parsing_ok = true; |
117 | if (absl::ConsumePrefix(&arg, "--" ) && absl::ConsumePrefix(&arg, flag) && |
118 | absl::ConsumePrefix(&arg, "=" )) { |
119 | char ; |
120 | float parsed_float; |
121 | if (sscanf(arg.data(), "%f%c" , &parsed_float, &extra) != 1) { |
122 | LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag |
123 | << "." ; |
124 | *value_parsing_ok = false; |
125 | } else { |
126 | *value_parsing_ok = hook(parsed_float); |
127 | } |
128 | return true; |
129 | } |
130 | |
131 | return false; |
132 | } |
133 | |
134 | } // namespace |
135 | |
136 | Flag::Flag(const char* name, int32_t* dst, const string& usage_text, |
137 | bool* dst_updated) |
138 | : name_(name), |
139 | type_(TYPE_INT32), |
140 | int32_hook_([dst, dst_updated](int32_t value) { |
141 | *dst = value; |
142 | if (dst_updated) *dst_updated = true; |
143 | return true; |
144 | }), |
145 | int32_default_for_display_(*dst), |
146 | usage_text_(usage_text) {} |
147 | |
148 | Flag::Flag(const char* name, int64_t* dst, const string& usage_text, |
149 | bool* dst_updated) |
150 | : name_(name), |
151 | type_(TYPE_INT64), |
152 | int64_hook_([dst, dst_updated](int64_t value) { |
153 | *dst = value; |
154 | if (dst_updated) *dst_updated = true; |
155 | return true; |
156 | }), |
157 | int64_default_for_display_(*dst), |
158 | usage_text_(usage_text) {} |
159 | |
160 | Flag::Flag(const char* name, float* dst, const string& usage_text, |
161 | bool* dst_updated) |
162 | : name_(name), |
163 | type_(TYPE_FLOAT), |
164 | float_hook_([dst, dst_updated](float value) { |
165 | *dst = value; |
166 | if (dst_updated) *dst_updated = true; |
167 | return true; |
168 | }), |
169 | float_default_for_display_(*dst), |
170 | usage_text_(usage_text) {} |
171 | |
172 | Flag::Flag(const char* name, bool* dst, const string& usage_text, |
173 | bool* dst_updated) |
174 | : name_(name), |
175 | type_(TYPE_BOOL), |
176 | bool_hook_([dst, dst_updated](bool value) { |
177 | *dst = value; |
178 | if (dst_updated) *dst_updated = true; |
179 | return true; |
180 | }), |
181 | bool_default_for_display_(*dst), |
182 | usage_text_(usage_text) {} |
183 | |
184 | Flag::Flag(const char* name, string* dst, const string& usage_text, |
185 | bool* dst_updated) |
186 | : name_(name), |
187 | type_(TYPE_STRING), |
188 | string_hook_([dst, dst_updated](string value) { |
189 | *dst = std::move(value); |
190 | if (dst_updated) *dst_updated = true; |
191 | return true; |
192 | }), |
193 | string_default_for_display_(*dst), |
194 | usage_text_(usage_text) {} |
195 | |
196 | Flag::Flag(const char* name, std::function<bool(int32_t)> int32_hook, |
197 | int32_t default_value_for_display, const string& usage_text) |
198 | : name_(name), |
199 | type_(TYPE_INT32), |
200 | int32_hook_(std::move(int32_hook)), |
201 | int32_default_for_display_(default_value_for_display), |
202 | usage_text_(usage_text) {} |
203 | |
204 | Flag::Flag(const char* name, std::function<bool(int64_t)> int64_hook, |
205 | int64_t default_value_for_display, const string& usage_text) |
206 | : name_(name), |
207 | type_(TYPE_INT64), |
208 | int64_hook_(std::move(int64_hook)), |
209 | int64_default_for_display_(default_value_for_display), |
210 | usage_text_(usage_text) {} |
211 | |
212 | Flag::Flag(const char* name, std::function<bool(float)> float_hook, |
213 | float default_value_for_display, const string& usage_text) |
214 | : name_(name), |
215 | type_(TYPE_FLOAT), |
216 | float_hook_(std::move(float_hook)), |
217 | float_default_for_display_(default_value_for_display), |
218 | usage_text_(usage_text) {} |
219 | |
220 | Flag::Flag(const char* name, std::function<bool(bool)> bool_hook, |
221 | bool default_value_for_display, const string& usage_text) |
222 | : name_(name), |
223 | type_(TYPE_BOOL), |
224 | bool_hook_(std::move(bool_hook)), |
225 | bool_default_for_display_(default_value_for_display), |
226 | usage_text_(usage_text) {} |
227 | |
228 | Flag::Flag(const char* name, std::function<bool(string)> string_hook, |
229 | string default_value_for_display, const string& usage_text) |
230 | : name_(name), |
231 | type_(TYPE_STRING), |
232 | string_hook_(std::move(string_hook)), |
233 | string_default_for_display_(std::move(default_value_for_display)), |
234 | usage_text_(usage_text) {} |
235 | |
236 | bool Flag::Parse(string arg, bool* value_parsing_ok) const { |
237 | bool result = false; |
238 | if (type_ == TYPE_INT32) { |
239 | result = ParseInt32Flag(arg, name_, int32_hook_, value_parsing_ok); |
240 | } else if (type_ == TYPE_INT64) { |
241 | result = ParseInt64Flag(arg, name_, int64_hook_, value_parsing_ok); |
242 | } else if (type_ == TYPE_BOOL) { |
243 | result = ParseBoolFlag(arg, name_, bool_hook_, value_parsing_ok); |
244 | } else if (type_ == TYPE_STRING) { |
245 | result = ParseStringFlag(arg, name_, string_hook_, value_parsing_ok); |
246 | } else if (type_ == TYPE_FLOAT) { |
247 | result = ParseFloatFlag(arg, name_, float_hook_, value_parsing_ok); |
248 | } |
249 | return result; |
250 | } |
251 | |
252 | /*static*/ bool Flags::Parse(int* argc, char** argv, |
253 | const std::vector<Flag>& flag_list) { |
254 | bool result = true; |
255 | std::vector<char*> unknown_flags; |
256 | for (int i = 1; i < *argc; ++i) { |
257 | if (string(argv[i]) == "--" ) { |
258 | while (i < *argc) { |
259 | unknown_flags.push_back(argv[i]); |
260 | ++i; |
261 | } |
262 | break; |
263 | } |
264 | |
265 | bool was_found = false; |
266 | for (const Flag& flag : flag_list) { |
267 | bool value_parsing_ok; |
268 | was_found = flag.Parse(argv[i], &value_parsing_ok); |
269 | if (!value_parsing_ok) { |
270 | result = false; |
271 | } |
272 | if (was_found) { |
273 | break; |
274 | } |
275 | } |
276 | if (!was_found) { |
277 | unknown_flags.push_back(argv[i]); |
278 | } |
279 | } |
280 | // Passthrough any extra flags. |
281 | int dst = 1; // Skip argv[0] |
282 | for (char* f : unknown_flags) { |
283 | argv[dst++] = f; |
284 | } |
285 | argv[dst++] = nullptr; |
286 | *argc = unknown_flags.size() + 1; |
287 | return result && (*argc < 2 || strcmp(argv[1], "--help" ) != 0); |
288 | } |
289 | |
290 | /*static*/ string Flags::Usage(const string& cmdline, |
291 | const std::vector<Flag>& flag_list) { |
292 | string usage_text; |
293 | if (!flag_list.empty()) { |
294 | strings::Appendf(&usage_text, "usage: %s\nFlags:\n" , cmdline.c_str()); |
295 | } else { |
296 | strings::Appendf(&usage_text, "usage: %s\n" , cmdline.c_str()); |
297 | } |
298 | for (const Flag& flag : flag_list) { |
299 | const char* type_name = "" ; |
300 | string flag_string; |
301 | if (flag.type_ == Flag::TYPE_INT32) { |
302 | type_name = "int32" ; |
303 | flag_string = strings::Printf("--%s=%d" , flag.name_.c_str(), |
304 | flag.int32_default_for_display_); |
305 | } else if (flag.type_ == Flag::TYPE_INT64) { |
306 | type_name = "int64" ; |
307 | flag_string = strings::Printf( |
308 | "--%s=%lld" , flag.name_.c_str(), |
309 | static_cast<long long>(flag.int64_default_for_display_)); |
310 | } else if (flag.type_ == Flag::TYPE_BOOL) { |
311 | type_name = "bool" ; |
312 | flag_string = |
313 | strings::Printf("--%s=%s" , flag.name_.c_str(), |
314 | flag.bool_default_for_display_ ? "true" : "false" ); |
315 | } else if (flag.type_ == Flag::TYPE_STRING) { |
316 | type_name = "string" ; |
317 | flag_string = strings::Printf("--%s=\"%s\"" , flag.name_.c_str(), |
318 | flag.string_default_for_display_.c_str()); |
319 | } else if (flag.type_ == Flag::TYPE_FLOAT) { |
320 | type_name = "float" ; |
321 | flag_string = strings::Printf("--%s=%f" , flag.name_.c_str(), |
322 | flag.float_default_for_display_); |
323 | } |
324 | strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n" , flag_string.c_str(), |
325 | type_name, flag.usage_text_.c_str()); |
326 | } |
327 | return usage_text; |
328 | } |
329 | |
330 | } // namespace tsl |
331 | |