1#include <c10/macros/Macros.h>
2#include <c10/util/Flags.h>
3
#include <cerrno>
#include <cstdlib>
#include <deque>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
8
9#ifndef C10_USE_GFLAGS
10
11namespace c10 {
12
using std::string;

// Registry of per-flag parsers. ParseCommandLineFlags() checks a flag name
// with C10FlagsRegistry()->Has(name). It then creates the flag's
// C10FlagParser with C10FlagsRegistry()->Create(name, value), where `value`
// is the flag's value text; the constructor argument type is `const string&`.
C10_DEFINE_REGISTRY(C10FlagsRegistry, C10FlagParser, const string&);
16
namespace {

// Flipped to true by ParseCommandLineFlags() once it has processed argv.
bool gCommandLineFlagsParsed = false;

// Flag parsing runs before logging is available, so diagnostics are
// accumulated in this buffer and only printed (to std::cerr) when parsing
// fails.
std::stringstream& GlobalInitStream() {
  static std::stringstream pending_messages;
  return pending_messages;
}

// Text shown by --help ahead of the flag list; SetUsageMessage() repoints it.
const char* gUsageMessage = "(Usage message not set.)";

} // namespace
28
29C10_EXPORT void SetUsageMessage(const string& str) {
30 static string usage_message_safe_copy = str;
31 gUsageMessage = usage_message_safe_copy.c_str();
32}
33
34C10_EXPORT const char* UsageMessage() {
35 return gUsageMessage;
36}
37
38C10_EXPORT bool ParseCommandLineFlags(int* pargc, char*** pargv) {
39 if (*pargc == 0)
40 return true;
41 char** argv = *pargv;
42 bool success = true;
43 GlobalInitStream() << "Parsing commandline arguments for c10." << std::endl;
44 // write_head is the location we write the unused arguments to.
45 int write_head = 1;
46 for (int i = 1; i < *pargc; ++i) {
47 string arg(argv[i]);
48
49 if (arg.find("--help") != string::npos) {
50 // Print the help message, and quit.
51 std::cout << UsageMessage() << std::endl;
52 std::cout << "Arguments: " << std::endl;
53 for (const auto& help_msg : C10FlagsRegistry()->HelpMessage()) {
54 std::cout << " " << help_msg.first << ": " << help_msg.second
55 << std::endl;
56 }
57 exit(0);
58 }
59 // If the arg does not start with "--", we will ignore it.
60 if (arg[0] != '-' || arg[1] != '-') {
61 GlobalInitStream()
62 << "C10 flag: commandline argument does not match --name=var "
63 "or --name format: "
64 << arg << ". Ignoring this argument." << std::endl;
65 argv[write_head++] = argv[i];
66 continue;
67 }
68
69 string key;
70 string value;
71 size_t prefix_idx = arg.find('=');
72 if (prefix_idx == string::npos) {
73 // If there is no equality char in the arg, it means that the
74 // arg is specified in the next argument.
75 key = arg.substr(2, arg.size() - 2);
76 ++i;
77 if (i == *pargc) {
78 GlobalInitStream()
79 << "C10 flag: reached the last commandline argument, but "
80 "I am expecting a value for "
81 << arg;
82 success = false;
83 break;
84 }
85 value = string(argv[i]);
86 } else {
87 // If there is an equality character, we will basically use the value
88 // after the "=".
89 key = arg.substr(2, prefix_idx - 2);
90 value = arg.substr(prefix_idx + 1, string::npos);
91 }
92 // If the flag is not registered, we will ignore it.
93 if (!C10FlagsRegistry()->Has(key)) {
94 GlobalInitStream() << "C10 flag: unrecognized commandline argument: "
95 << arg << std::endl;
96 success = false;
97 break;
98 }
99 std::unique_ptr<C10FlagParser> parser(
100 C10FlagsRegistry()->Create(key, value));
101 if (!parser->success()) {
102 GlobalInitStream() << "C10 flag: illegal argument: " << arg << std::endl;
103 success = false;
104 break;
105 }
106 }
107 *pargc = write_head;
108 gCommandLineFlagsParsed = true;
109 // TODO: when we fail commandline flag parsing, shall we continue, or
110 // shall we just quit loudly? Right now we carry on the computation, but
111 // since there are failures in parsing, it is very likely that some
112 // downstream things will break, in which case it makes sense to quit loud
113 // and early.
114 if (!success) {
115 std::cerr << GlobalInitStream().str();
116 }
117 // Clear the global init stream.
118 GlobalInitStream().str(std::string());
119 return success;
120}
121
122C10_EXPORT bool CommandLineFlagsHasBeenParsed() {
123 return gCommandLineFlagsParsed;
124}
125
126template <>
127C10_EXPORT bool C10FlagParser::Parse<string>(
128 const string& content,
129 string* value) {
130 *value = content;
131 return true;
132}
133
134template <>
135C10_EXPORT bool C10FlagParser::Parse<int>(const string& content, int* value) {
136 try {
137 *value = std::atoi(content.c_str());
138 return true;
139 } catch (...) {
140 GlobalInitStream() << "C10 flag error: Cannot convert argument to int: "
141 << content << std::endl;
142 return false;
143 }
144}
145
146template <>
147C10_EXPORT bool C10FlagParser::Parse<int64_t>(
148 const string& content,
149 int64_t* value) {
150 try {
151 static_assert(sizeof(long long) == sizeof(int64_t));
152#ifdef __ANDROID__
153 // Android does not have std::atoll.
154 *value = atoll(content.c_str());
155#else
156 *value = std::atoll(content.c_str());
157#endif
158 return true;
159 } catch (...) {
160 GlobalInitStream() << "C10 flag error: Cannot convert argument to int: "
161 << content << std::endl;
162 return false;
163 }
164}
165
166template <>
167C10_EXPORT bool C10FlagParser::Parse<double>(
168 const string& content,
169 double* value) {
170 try {
171 *value = std::atof(content.c_str());
172 return true;
173 } catch (...) {
174 GlobalInitStream() << "C10 flag error: Cannot convert argument to double: "
175 << content << std::endl;
176 return false;
177 }
178}
179
180template <>
181C10_EXPORT bool C10FlagParser::Parse<bool>(const string& content, bool* value) {
182 if (content == "false" || content == "False" || content == "FALSE" ||
183 content == "0") {
184 *value = false;
185 return true;
186 } else if (
187 content == "true" || content == "True" || content == "TRUE" ||
188 content == "1") {
189 *value = true;
190 return true;
191 } else {
192 GlobalInitStream()
193 << "C10 flag error: Cannot convert argument to bool: " << content
194 << std::endl
195 << "Note that if you are passing in a bool flag, you need to "
196 "explicitly specify it, like --arg=True or --arg True. Otherwise, "
197 "the next argument may be inadvertently used as the argument, "
198 "causing the above error."
199 << std::endl;
200 return false;
201 }
202}
203
204} // namespace c10
205
206#endif // C10_USE_GFLAGS
207