#include <c10/macros/Macros.h>
#include <c10/util/Flags.h>

#include <cerrno>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
8 | |
9 | #ifndef C10_USE_GFLAGS |
10 | |
11 | namespace c10 { |
12 | |
using std::string;

// Defines the registry of flag parsers. The matching declaration presumably
// lives in c10/util/Flags.h. Entries are keyed by flag name, and each entry
// constructs a C10FlagParser from the flag's raw string value;
// ParseCommandLineFlags() queries it through Has(), Create() and
// HelpMessage().
C10_DEFINE_REGISTRY(C10FlagsRegistry, C10FlagParser, const string&);
16 | |
namespace {
// Flipped to true by ParseCommandLineFlags() once it has walked a non-empty
// argv; read back through CommandLineFlagsHasBeenParsed().
bool gCommandLineFlagsParsed = false;

// Flags are parsed before logging is set up, so diagnostics produced while
// parsing are collected in this buffer rather than sent to a logger. The
// buffer is a function-local static, so it is created on first use.
std::stringstream& GlobalInitStream() {
  static std::stringstream buffered_messages;
  return buffered_messages;
}

// Text handed out by UsageMessage(); SetUsageMessage() repoints it.
const char* gUsageMessage = "(Usage message not set.)";
} // namespace
28 | |
29 | C10_EXPORT void SetUsageMessage(const string& str) { |
30 | static string usage_message_safe_copy = str; |
31 | gUsageMessage = usage_message_safe_copy.c_str(); |
32 | } |
33 | |
34 | C10_EXPORT const char* UsageMessage() { |
35 | return gUsageMessage; |
36 | } |
37 | |
38 | C10_EXPORT bool ParseCommandLineFlags(int* pargc, char*** pargv) { |
39 | if (*pargc == 0) |
40 | return true; |
41 | char** argv = *pargv; |
42 | bool success = true; |
43 | GlobalInitStream() << "Parsing commandline arguments for c10." << std::endl; |
44 | // write_head is the location we write the unused arguments to. |
45 | int write_head = 1; |
46 | for (int i = 1; i < *pargc; ++i) { |
47 | string arg(argv[i]); |
48 | |
49 | if (arg.find("--help" ) != string::npos) { |
50 | // Print the help message, and quit. |
51 | std::cout << UsageMessage() << std::endl; |
52 | std::cout << "Arguments: " << std::endl; |
53 | for (const auto& help_msg : C10FlagsRegistry()->HelpMessage()) { |
54 | std::cout << " " << help_msg.first << ": " << help_msg.second |
55 | << std::endl; |
56 | } |
57 | exit(0); |
58 | } |
59 | // If the arg does not start with "--", we will ignore it. |
60 | if (arg[0] != '-' || arg[1] != '-') { |
61 | GlobalInitStream() |
62 | << "C10 flag: commandline argument does not match --name=var " |
63 | "or --name format: " |
64 | << arg << ". Ignoring this argument." << std::endl; |
65 | argv[write_head++] = argv[i]; |
66 | continue; |
67 | } |
68 | |
69 | string key; |
70 | string value; |
71 | size_t prefix_idx = arg.find('='); |
72 | if (prefix_idx == string::npos) { |
73 | // If there is no equality char in the arg, it means that the |
74 | // arg is specified in the next argument. |
75 | key = arg.substr(2, arg.size() - 2); |
76 | ++i; |
77 | if (i == *pargc) { |
78 | GlobalInitStream() |
79 | << "C10 flag: reached the last commandline argument, but " |
80 | "I am expecting a value for " |
81 | << arg; |
82 | success = false; |
83 | break; |
84 | } |
85 | value = string(argv[i]); |
86 | } else { |
87 | // If there is an equality character, we will basically use the value |
88 | // after the "=". |
89 | key = arg.substr(2, prefix_idx - 2); |
90 | value = arg.substr(prefix_idx + 1, string::npos); |
91 | } |
92 | // If the flag is not registered, we will ignore it. |
93 | if (!C10FlagsRegistry()->Has(key)) { |
94 | GlobalInitStream() << "C10 flag: unrecognized commandline argument: " |
95 | << arg << std::endl; |
96 | success = false; |
97 | break; |
98 | } |
99 | std::unique_ptr<C10FlagParser> parser( |
100 | C10FlagsRegistry()->Create(key, value)); |
101 | if (!parser->success()) { |
102 | GlobalInitStream() << "C10 flag: illegal argument: " << arg << std::endl; |
103 | success = false; |
104 | break; |
105 | } |
106 | } |
107 | *pargc = write_head; |
108 | gCommandLineFlagsParsed = true; |
109 | // TODO: when we fail commandline flag parsing, shall we continue, or |
110 | // shall we just quit loudly? Right now we carry on the computation, but |
111 | // since there are failures in parsing, it is very likely that some |
112 | // downstream things will break, in which case it makes sense to quit loud |
113 | // and early. |
114 | if (!success) { |
115 | std::cerr << GlobalInitStream().str(); |
116 | } |
117 | // Clear the global init stream. |
118 | GlobalInitStream().str(std::string()); |
119 | return success; |
120 | } |
121 | |
122 | C10_EXPORT bool CommandLineFlagsHasBeenParsed() { |
123 | return gCommandLineFlagsParsed; |
124 | } |
125 | |
126 | template <> |
127 | C10_EXPORT bool C10FlagParser::Parse<string>( |
128 | const string& content, |
129 | string* value) { |
130 | *value = content; |
131 | return true; |
132 | } |
133 | |
134 | template <> |
135 | C10_EXPORT bool C10FlagParser::Parse<int>(const string& content, int* value) { |
136 | try { |
137 | *value = std::atoi(content.c_str()); |
138 | return true; |
139 | } catch (...) { |
140 | GlobalInitStream() << "C10 flag error: Cannot convert argument to int: " |
141 | << content << std::endl; |
142 | return false; |
143 | } |
144 | } |
145 | |
146 | template <> |
147 | C10_EXPORT bool C10FlagParser::Parse<int64_t>( |
148 | const string& content, |
149 | int64_t* value) { |
150 | try { |
151 | static_assert(sizeof(long long) == sizeof(int64_t)); |
152 | #ifdef __ANDROID__ |
153 | // Android does not have std::atoll. |
154 | *value = atoll(content.c_str()); |
155 | #else |
156 | *value = std::atoll(content.c_str()); |
157 | #endif |
158 | return true; |
159 | } catch (...) { |
160 | GlobalInitStream() << "C10 flag error: Cannot convert argument to int: " |
161 | << content << std::endl; |
162 | return false; |
163 | } |
164 | } |
165 | |
166 | template <> |
167 | C10_EXPORT bool C10FlagParser::Parse<double>( |
168 | const string& content, |
169 | double* value) { |
170 | try { |
171 | *value = std::atof(content.c_str()); |
172 | return true; |
173 | } catch (...) { |
174 | GlobalInitStream() << "C10 flag error: Cannot convert argument to double: " |
175 | << content << std::endl; |
176 | return false; |
177 | } |
178 | } |
179 | |
180 | template <> |
181 | C10_EXPORT bool C10FlagParser::Parse<bool>(const string& content, bool* value) { |
182 | if (content == "false" || content == "False" || content == "FALSE" || |
183 | content == "0" ) { |
184 | *value = false; |
185 | return true; |
186 | } else if ( |
187 | content == "true" || content == "True" || content == "TRUE" || |
188 | content == "1" ) { |
189 | *value = true; |
190 | return true; |
191 | } else { |
192 | GlobalInitStream() |
193 | << "C10 flag error: Cannot convert argument to bool: " << content |
194 | << std::endl |
195 | << "Note that if you are passing in a bool flag, you need to " |
196 | "explicitly specify it, like --arg=True or --arg True. Otherwise, " |
197 | "the next argument may be inadvertently used as the argument, " |
198 | "causing the above error." |
199 | << std::endl; |
200 | return false; |
201 | } |
202 | } |
203 | |
204 | } // namespace c10 |
205 | |
206 | #endif // C10_USE_GFLAGS |
207 | |