1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_TSL_UTIL_COMMAND_LINE_FLAGS_H_ |
17 | #define TENSORFLOW_TSL_UTIL_COMMAND_LINE_FLAGS_H_ |
18 | |
19 | #include <functional> |
20 | #include <string> |
21 | #include <vector> |
22 | |
23 | #include "tensorflow/tsl/platform/types.h" |
24 | |
25 | namespace tsl { |
26 | |
27 | // N.B. This library is for INTERNAL use only. |
28 | // |
29 | // This is a simple command-line argument parsing module to help us handle |
30 | // parameters for C++ binaries. The recommended way of using it is with local |
31 | // variables and an initializer list of Flag objects, for example: |
32 | // |
33 | // int some_int = 10; |
34 | // bool some_switch = false; |
35 | // string some_name = "something"; |
36 | // std::vector<tsl::Flag> flag_list = { |
37 | // Flag("some_int", &some_int, "an integer that affects X"), |
38 | // Flag("some_switch", &some_switch, "a bool that affects Y"), |
39 | // Flag("some_name", &some_name, "a string that affects Z") |
40 | // }; |
// // Get usage message before Flags::Parse() to capture default values.
// string usage = Flags::Usage(argv[0], flag_list);
43 | // bool parsed_values_ok = Flags::Parse(&argc, argv, flag_list); |
44 | // |
45 | // tsl::port::InitMain(usage.c_str(), &argc, &argv); |
46 | // if (argc != 1 || !parsed_values_ok) { |
47 | // ...output usage and error message... |
48 | // } |
49 | // |
50 | // The argc and argv values are adjusted by the Parse function so all that |
51 | // remains is the program name (at argv[0]) and any unknown arguments fill the |
// rest of the array. This means you can check for flags that weren't understood
// by seeing if argc is greater than 1.
54 | // The result indicates if there were any errors parsing the values that were |
55 | // passed to the command-line switches. For example, --some_int=foo would return |
56 | // false because the argument is expected to be an integer. |
57 | // |
58 | // NOTE: Unlike gflags-style libraries, this library is intended to be |
59 | // used in the `main()` function of your binary. It does not handle |
60 | // flag definitions that are scattered around the source code. |
61 | |
62 | // A description of a single command line flag, holding its name, type, usage |
63 | // text, and a pointer to the corresponding variable. |
64 | class Flag { |
65 | public: |
66 | Flag(const char* name, int32* dst, const string& usage_text, |
67 | bool* dst_updated = nullptr); |
68 | Flag(const char* name, int64_t* dst, const string& usage_text, |
69 | bool* dst_updated = nullptr); |
70 | Flag(const char* name, bool* dst, const string& usage_text, |
71 | bool* dst_updated = nullptr); |
72 | Flag(const char* name, string* dst, const string& usage_text, |
73 | bool* dst_updated = nullptr); |
74 | Flag(const char* name, float* dst, const string& usage_text, |
75 | bool* dst_updated = nullptr); |
76 | |
77 | // These constructors invoke a hook on a match instead of writing to a |
78 | // specific memory location. The hook may return false to signal a malformed |
79 | // or illegal value, which will then fail the command line parse. |
80 | // |
81 | // "default_value_for_display" is shown as the default value of this flag in |
82 | // Flags::Usage(). |
83 | Flag(const char* name, std::function<bool(int32_t)> int32_hook, |
84 | int32_t default_value_for_display, const string& usage_text); |
85 | Flag(const char* name, std::function<bool(int64_t)> int64_hook, |
86 | int64_t default_value_for_display, const string& usage_text); |
87 | Flag(const char* name, std::function<bool(float)> float_hook, |
88 | float default_value_for_display, const string& usage_text); |
89 | Flag(const char* name, std::function<bool(bool)> bool_hook, |
90 | bool default_value_for_display, const string& usage_text); |
91 | Flag(const char* name, std::function<bool(string)> string_hook, |
92 | string default_value_for_display, const string& usage_text); |
93 | |
94 | bool is_default_initialized() const { return default_initialized_; } |
95 | |
96 | private: |
97 | friend class Flags; |
98 | |
99 | bool Parse(string arg, bool* value_parsing_ok) const; |
100 | |
101 | string name_; |
102 | enum { |
103 | TYPE_INT32, |
104 | TYPE_INT64, |
105 | TYPE_BOOL, |
106 | TYPE_STRING, |
107 | TYPE_FLOAT, |
108 | } type_; |
109 | |
110 | std::function<bool(int32_t)> int32_hook_; |
111 | int32 int32_default_for_display_; |
112 | |
113 | std::function<bool(int64_t)> int64_hook_; |
114 | int64_t int64_default_for_display_; |
115 | |
116 | std::function<bool(float)> float_hook_; |
117 | float float_default_for_display_; |
118 | |
119 | std::function<bool(bool)> bool_hook_; |
120 | bool bool_default_for_display_; |
121 | |
122 | std::function<bool(string)> string_hook_; |
123 | string string_default_for_display_; |
124 | |
125 | string usage_text_; |
126 | bool default_initialized_ = true; |
127 | }; |
128 | |
129 | class Flags { |
130 | public: |
131 | // Parse the command line represented by argv[0, ..., (*argc)-1] to find flag |
132 | // instances matching flags in flaglist[]. Update the variables associated |
133 | // with matching flags, and remove the matching arguments from (*argc, argv). |
134 | // Return true iff all recognized flag values were parsed correctly, and the |
135 | // first remaining argument is not "--help". |
136 | static bool Parse(int* argc, char** argv, const std::vector<Flag>& flag_list); |
137 | |
138 | // Return a usage message with command line cmdline, and the |
139 | // usage_text strings in flag_list[]. |
140 | static string Usage(const string& cmdline, |
141 | const std::vector<Flag>& flag_list); |
142 | }; |
143 | |
144 | } // namespace tsl |
145 | |
146 | #endif // TENSORFLOW_TSL_UTIL_COMMAND_LINE_FLAGS_H_ |
147 | |