1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/tsl/util/command_line_flags.h"
17
#include <cerrno>
#include <cinttypes>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <string>
#include <vector>
22
23#include "tensorflow/tsl/platform/logging.h"
24#include "tensorflow/tsl/platform/str_util.h"
25#include "tensorflow/tsl/platform/stringpiece.h"
26#include "tensorflow/tsl/platform/stringprintf.h"
27
28namespace tsl {
29namespace {
30
31bool ParseStringFlag(StringPiece arg, StringPiece flag,
32 const std::function<bool(string)>& hook,
33 bool* value_parsing_ok) {
34 *value_parsing_ok = true;
35 if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag) &&
36 absl::ConsumePrefix(&arg, "=")) {
37 *value_parsing_ok = hook(string(arg));
38 return true;
39 }
40
41 return false;
42}
43
44bool ParseInt32Flag(StringPiece arg, StringPiece flag,
45 const std::function<bool(int32_t)>& hook,
46 bool* value_parsing_ok) {
47 *value_parsing_ok = true;
48 if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag) &&
49 absl::ConsumePrefix(&arg, "=")) {
50 char extra;
51 int32_t parsed_int32;
52 if (sscanf(arg.data(), "%d%c", &parsed_int32, &extra) != 1) {
53 LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
54 << ".";
55 *value_parsing_ok = false;
56 } else {
57 *value_parsing_ok = hook(parsed_int32);
58 }
59 return true;
60 }
61
62 return false;
63}
64
65bool ParseInt64Flag(StringPiece arg, StringPiece flag,
66 const std::function<bool(int64_t)>& hook,
67 bool* value_parsing_ok) {
68 *value_parsing_ok = true;
69 if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag) &&
70 absl::ConsumePrefix(&arg, "=")) {
71 char extra;
72 int64_t parsed_int64;
73 if (sscanf(arg.data(), "%" SCNd64 "%c", &parsed_int64, &extra) != 1) {
74 LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
75 << ".";
76 *value_parsing_ok = false;
77 } else {
78 *value_parsing_ok = hook(parsed_int64);
79 }
80 return true;
81 }
82
83 return false;
84}
85
86bool ParseBoolFlag(StringPiece arg, StringPiece flag,
87 const std::function<bool(bool)>& hook,
88 bool* value_parsing_ok) {
89 *value_parsing_ok = true;
90 if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag)) {
91 if (arg.empty()) {
92 *value_parsing_ok = hook(true);
93 return true;
94 }
95
96 if (arg == "=true") {
97 *value_parsing_ok = hook(true);
98 return true;
99 } else if (arg == "=false") {
100 *value_parsing_ok = hook(false);
101 return true;
102 } else {
103 LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
104 << ".";
105 *value_parsing_ok = false;
106 return true;
107 }
108 }
109
110 return false;
111}
112
113bool ParseFloatFlag(StringPiece arg, StringPiece flag,
114 const std::function<bool(float)>& hook,
115 bool* value_parsing_ok) {
116 *value_parsing_ok = true;
117 if (absl::ConsumePrefix(&arg, "--") && absl::ConsumePrefix(&arg, flag) &&
118 absl::ConsumePrefix(&arg, "=")) {
119 char extra;
120 float parsed_float;
121 if (sscanf(arg.data(), "%f%c", &parsed_float, &extra) != 1) {
122 LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
123 << ".";
124 *value_parsing_ok = false;
125 } else {
126 *value_parsing_ok = hook(parsed_float);
127 }
128 return true;
129 }
130
131 return false;
132}
133
134} // namespace
135
136Flag::Flag(const char* name, int32_t* dst, const string& usage_text,
137 bool* dst_updated)
138 : name_(name),
139 type_(TYPE_INT32),
140 int32_hook_([dst, dst_updated](int32_t value) {
141 *dst = value;
142 if (dst_updated) *dst_updated = true;
143 return true;
144 }),
145 int32_default_for_display_(*dst),
146 usage_text_(usage_text) {}
147
148Flag::Flag(const char* name, int64_t* dst, const string& usage_text,
149 bool* dst_updated)
150 : name_(name),
151 type_(TYPE_INT64),
152 int64_hook_([dst, dst_updated](int64_t value) {
153 *dst = value;
154 if (dst_updated) *dst_updated = true;
155 return true;
156 }),
157 int64_default_for_display_(*dst),
158 usage_text_(usage_text) {}
159
160Flag::Flag(const char* name, float* dst, const string& usage_text,
161 bool* dst_updated)
162 : name_(name),
163 type_(TYPE_FLOAT),
164 float_hook_([dst, dst_updated](float value) {
165 *dst = value;
166 if (dst_updated) *dst_updated = true;
167 return true;
168 }),
169 float_default_for_display_(*dst),
170 usage_text_(usage_text) {}
171
172Flag::Flag(const char* name, bool* dst, const string& usage_text,
173 bool* dst_updated)
174 : name_(name),
175 type_(TYPE_BOOL),
176 bool_hook_([dst, dst_updated](bool value) {
177 *dst = value;
178 if (dst_updated) *dst_updated = true;
179 return true;
180 }),
181 bool_default_for_display_(*dst),
182 usage_text_(usage_text) {}
183
184Flag::Flag(const char* name, string* dst, const string& usage_text,
185 bool* dst_updated)
186 : name_(name),
187 type_(TYPE_STRING),
188 string_hook_([dst, dst_updated](string value) {
189 *dst = std::move(value);
190 if (dst_updated) *dst_updated = true;
191 return true;
192 }),
193 string_default_for_display_(*dst),
194 usage_text_(usage_text) {}
195
196Flag::Flag(const char* name, std::function<bool(int32_t)> int32_hook,
197 int32_t default_value_for_display, const string& usage_text)
198 : name_(name),
199 type_(TYPE_INT32),
200 int32_hook_(std::move(int32_hook)),
201 int32_default_for_display_(default_value_for_display),
202 usage_text_(usage_text) {}
203
204Flag::Flag(const char* name, std::function<bool(int64_t)> int64_hook,
205 int64_t default_value_for_display, const string& usage_text)
206 : name_(name),
207 type_(TYPE_INT64),
208 int64_hook_(std::move(int64_hook)),
209 int64_default_for_display_(default_value_for_display),
210 usage_text_(usage_text) {}
211
212Flag::Flag(const char* name, std::function<bool(float)> float_hook,
213 float default_value_for_display, const string& usage_text)
214 : name_(name),
215 type_(TYPE_FLOAT),
216 float_hook_(std::move(float_hook)),
217 float_default_for_display_(default_value_for_display),
218 usage_text_(usage_text) {}
219
220Flag::Flag(const char* name, std::function<bool(bool)> bool_hook,
221 bool default_value_for_display, const string& usage_text)
222 : name_(name),
223 type_(TYPE_BOOL),
224 bool_hook_(std::move(bool_hook)),
225 bool_default_for_display_(default_value_for_display),
226 usage_text_(usage_text) {}
227
228Flag::Flag(const char* name, std::function<bool(string)> string_hook,
229 string default_value_for_display, const string& usage_text)
230 : name_(name),
231 type_(TYPE_STRING),
232 string_hook_(std::move(string_hook)),
233 string_default_for_display_(std::move(default_value_for_display)),
234 usage_text_(usage_text) {}
235
236bool Flag::Parse(string arg, bool* value_parsing_ok) const {
237 bool result = false;
238 if (type_ == TYPE_INT32) {
239 result = ParseInt32Flag(arg, name_, int32_hook_, value_parsing_ok);
240 } else if (type_ == TYPE_INT64) {
241 result = ParseInt64Flag(arg, name_, int64_hook_, value_parsing_ok);
242 } else if (type_ == TYPE_BOOL) {
243 result = ParseBoolFlag(arg, name_, bool_hook_, value_parsing_ok);
244 } else if (type_ == TYPE_STRING) {
245 result = ParseStringFlag(arg, name_, string_hook_, value_parsing_ok);
246 } else if (type_ == TYPE_FLOAT) {
247 result = ParseFloatFlag(arg, name_, float_hook_, value_parsing_ok);
248 }
249 return result;
250}
251
252/*static*/ bool Flags::Parse(int* argc, char** argv,
253 const std::vector<Flag>& flag_list) {
254 bool result = true;
255 std::vector<char*> unknown_flags;
256 for (int i = 1; i < *argc; ++i) {
257 if (string(argv[i]) == "--") {
258 while (i < *argc) {
259 unknown_flags.push_back(argv[i]);
260 ++i;
261 }
262 break;
263 }
264
265 bool was_found = false;
266 for (const Flag& flag : flag_list) {
267 bool value_parsing_ok;
268 was_found = flag.Parse(argv[i], &value_parsing_ok);
269 if (!value_parsing_ok) {
270 result = false;
271 }
272 if (was_found) {
273 break;
274 }
275 }
276 if (!was_found) {
277 unknown_flags.push_back(argv[i]);
278 }
279 }
280 // Passthrough any extra flags.
281 int dst = 1; // Skip argv[0]
282 for (char* f : unknown_flags) {
283 argv[dst++] = f;
284 }
285 argv[dst++] = nullptr;
286 *argc = unknown_flags.size() + 1;
287 return result && (*argc < 2 || strcmp(argv[1], "--help") != 0);
288}
289
290/*static*/ string Flags::Usage(const string& cmdline,
291 const std::vector<Flag>& flag_list) {
292 string usage_text;
293 if (!flag_list.empty()) {
294 strings::Appendf(&usage_text, "usage: %s\nFlags:\n", cmdline.c_str());
295 } else {
296 strings::Appendf(&usage_text, "usage: %s\n", cmdline.c_str());
297 }
298 for (const Flag& flag : flag_list) {
299 const char* type_name = "";
300 string flag_string;
301 if (flag.type_ == Flag::TYPE_INT32) {
302 type_name = "int32";
303 flag_string = strings::Printf("--%s=%d", flag.name_.c_str(),
304 flag.int32_default_for_display_);
305 } else if (flag.type_ == Flag::TYPE_INT64) {
306 type_name = "int64";
307 flag_string = strings::Printf(
308 "--%s=%lld", flag.name_.c_str(),
309 static_cast<long long>(flag.int64_default_for_display_));
310 } else if (flag.type_ == Flag::TYPE_BOOL) {
311 type_name = "bool";
312 flag_string =
313 strings::Printf("--%s=%s", flag.name_.c_str(),
314 flag.bool_default_for_display_ ? "true" : "false");
315 } else if (flag.type_ == Flag::TYPE_STRING) {
316 type_name = "string";
317 flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(),
318 flag.string_default_for_display_.c_str());
319 } else if (flag.type_ == Flag::TYPE_FLOAT) {
320 type_name = "float";
321 flag_string = strings::Printf("--%s=%f", flag.name_.c_str(),
322 flag.float_default_for_display_);
323 }
324 strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(),
325 type_name, flag.usage_text_.c_str());
326 }
327 return usage_text;
328}
329
330} // namespace tsl
331