1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef UTILS_PARSER_HPP |
18 | #define UTILS_PARSER_HPP |
19 | |
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sstream>
#include <string>
#include <vector>

#include "oneapi/dnnl/dnnl_types.h"

#include "dnn_types.hpp"
#include "dnnl_debug.hpp"
#include "tests/test_thread.hpp"
#include "utils/dims.hpp"
33 | |
34 | namespace parser { |
35 | |
// Parser state shared across translation units; declared `extern` here and
// defined outside this header.

// By its name, presumably records whether the most recently parsed argument
// was a problem descriptor rather than an option — TODO confirm in the
// implementation.
extern bool last_parsed_is_problem;
// Compared against `std::string` search results in this header (see
// `parse_vector_str`), so it is presumably `std::string::npos`.
extern const size_t eol;
// Stream for help text; presumably filled by
// `parser_utils::add_option_to_help` — confirm in the implementation.
extern std::stringstream help_ss;

namespace parser_utils {
// Returns the prefix an argument must start with to match @p option_name.
// The option parsers in this header compare that many leading characters of
// the argument against it. `parse_reset` passes @p with_args = false (an
// option without values); all other parsers here use the default `true`.
std::string get_pattern(const std::string &option_name, bool with_args = true);
// Called by every option parser in this header with its option name and help
// message; @p with_args mirrors the flag passed to `get_pattern`.
void add_option_to_help(const std::string &option,
        const std::string &help_message, bool with_args = true);
} // namespace parser_utils
45 | |
46 | template <typename T, typename F> |
47 | static bool parse_vector_str(T &vec, const T &def, F process_func, |
48 | const std::string &str, char delimeter = ',') { |
49 | const std::string s = str; |
50 | if (s.empty()) return vec = def, true; |
51 | |
52 | vec.clear(); |
53 | for (size_t pos_st = 0, pos_en = s.find_first_of(delimeter, pos_st); true; |
54 | pos_st = pos_en + 1, pos_en = s.find_first_of(delimeter, pos_st)) { |
55 | vec.push_back(process_func(s.substr(pos_st, pos_en - pos_st).c_str())); |
56 | if (pos_en == eol) break; |
57 | } |
58 | return true; |
59 | } |
60 | |
// Parses a two-level list. @p str is split by @p vector_delim into entries,
// and each entry is split by @p element_delim into elements, each converted
// with @p process_func. With the default delimiters, "a:b,c" yields
// {{f("a"), f("b")}, {f("c")}}.
// - An empty @p str sets @p vec to @p def.
// - An empty entry, e.g. in "a,,b", yields an empty inner container.
// Always returns true.
template <typename T, typename F>
static bool parse_multivector_str(std::vector<T> &vec,
        const std::vector<T> &def, F process_func, const std::string &str,
        char vector_delim = ',', char element_delim = ':') {
    auto process_subword = [&](const char *word) {
        T v;
        // The inner default is intentionally empty: an empty entry produces
        // an empty inner container rather than any driver-level default.
        const T empty_def_v {};
        // parse vector elements separated by @p element_delim
        parse_vector_str(v, empty_def_v, process_func, word, element_delim);
        return v;
    };

    // parse full vector separated by @p vector_delim
    return parse_vector_str(vec, def, process_subword, str, vector_delim);
}
75 | |
76 | template <typename T, typename F> |
77 | static bool parse_vector_option(T &vec, const T &def, F process_func, |
78 | const char *str, const std::string &option_name, |
79 | const std::string &help_message = "" ) { |
80 | parser_utils::add_option_to_help(option_name, help_message); |
81 | const std::string pattern = parser_utils::get_pattern(option_name); |
82 | if (pattern.find(str, 0, pattern.size()) == eol) return false; |
83 | return parse_vector_str(vec, def, process_func, str + pattern.size()); |
84 | } |
85 | |
86 | template <typename T, typename F> |
87 | static bool parse_multivector_option(std::vector<T> &vec, |
88 | const std::vector<T> &def, F process_func, const char *str, |
89 | const std::string &option_name, const std::string &help_message = "" , |
90 | char vector_delim = ',', char element_delim = ':') { |
91 | parser_utils::add_option_to_help(option_name, help_message); |
92 | const std::string pattern = parser_utils::get_pattern(option_name); |
93 | if (pattern.find(str, 0, pattern.size()) == eol) return false; |
94 | return parse_multivector_str(vec, def, process_func, str + pattern.size(), |
95 | vector_delim, element_delim); |
96 | } |
97 | |
98 | template <typename T, typename F> |
99 | static bool parse_single_value_option(T &val, const T &def_val, F process_func, |
100 | const char *str, const std::string &option_name, |
101 | const std::string &help_message = "" ) { |
102 | parser_utils::add_option_to_help(option_name, help_message); |
103 | const std::string pattern = parser_utils::get_pattern(option_name); |
104 | if (pattern.find(str, 0, pattern.size()) == eol) return false; |
105 | str = str + pattern.size(); |
106 | if (*str == '\0') return val = def_val, true; |
107 | return val = process_func(str), true; |
108 | } |
109 | |
// Vector option for data-type configurations, registered under
// @p option_name (default "cfg"). Delegates to `parse_vector_option` with a
// fixed help text and returns its result.
template <typename T, typename F>
static bool parse_cfg(T &vec, const T &def, F process_func, const char *str,
        const std::string &option_name = "cfg") {
    static const std::string cfg_help_text
            = "CFG (Default: `f32`)\n Specifies data types `CFG` for "
              "source, weights (if supported) and destination of operation.\n "
              " `CFG` values vary from driver to driver.\n";
    const bool parsed = parse_vector_option(
            vec, def, process_func, str, option_name, cfg_help_text);
    return parsed;
}
119 | |
// Vector option for operation algorithms, registered under @p option_name
// (default "alg"). Delegates to `parse_vector_option` with a fixed help text
// and returns its result.
template <typename T, typename F>
static bool parse_alg(T &vec, const T &def, F process_func, const char *str,
        const std::string &option_name = "alg") {
    static const std::string alg_help_text
            = "ALG (Default: depends on driver)\n Specifies operation "
              "algorithm `ALG`.\n `ALG` values vary from driver to "
              "driver.\n";
    const bool parsed = parse_vector_option(
            vec, def, process_func, str, option_name, alg_help_text);
    return parsed;
}
129 | |
// Parses a comma-separated list of attribute values into @p vec. Each entry
// is converted by default-constructing a `T` and calling its
// `from_str(entry)` under `SAFE_V`. A single default-constructed `T` serves
// as the default list. Returns the result of `parse_vector_option`.
template <typename T>
bool parse_subattr(std::vector<T> &vec, const char *str,
        const std::string &option_name, const std::string &help_message = "") {
    const auto from_str_entry = [](const std::string &entry) {
        T value;
        SAFE_V(value.from_str(entry));
        return value;
    };
    const std::vector<T> default_vec {T()};
    return parse_vector_option(
            vec, default_vec, from_str_entry, str, option_name, help_message);
}
142 | |
143 | template <typename S> |
144 | bool parse_reset(S &settings, const char *str, |
145 | const std::string &option_name = "reset" ) { |
146 | static const std::string help |
147 | = "\n Instructs the driver to reset driver specific options to " |
148 | "their default values.\n Neither global options nor " |
149 | "`--perf-template` option would be reset." ; |
150 | parser_utils::add_option_to_help(option_name, help, false); |
151 | |
152 | const std::string pattern = parser_utils::get_pattern(option_name, false); |
153 | if (pattern.find(str, 0, pattern.size()) == eol) return false; |
154 | settings.reset(); |
155 | return true; |
156 | } |
157 | |
// vector types
// The parsers below are implemented outside this header. Going by their
// signatures, each fills a vector of parsed values from `str` for the option
// named by `option_name`. Where a `def_*` argument exists, it presumably
// plays the same role as `def` in `parse_vector_option` — confirm in the
// implementation.
bool parse_dir(std::vector<dir_t> &dir, const std::vector<dir_t> &def_dir,
        const char *str, const std::string &option_name = "dir");

bool parse_dt(std::vector<dnnl_data_type_t> &dt,
        const std::vector<dnnl_data_type_t> &def_dt, const char *str,
        const std::string &option_name = "dt");

// Nested-vector counterpart of `parse_dt`: one inner vector of data types per
// list entry.
bool parse_multi_dt(std::vector<std::vector<dnnl_data_type_t>> &dt,
        const std::vector<std::vector<dnnl_data_type_t>> &def_dt,
        const char *str, const std::string &option_name = "sdt");

bool parse_tag(std::vector<std::string> &tag,
        const std::vector<std::string> &def_tag, const char *str,
        const std::string &option_name = "tag");

// Nested-vector counterpart of `parse_tag`: one inner vector of tags per list
// entry.
bool parse_multi_tag(std::vector<std::vector<std::string>> &tag,
        const std::vector<std::vector<std::string>> &def_tag, const char *str,
        const std::string &option_name = "stag");

bool parse_mb(std::vector<int64_t> &mb, const std::vector<int64_t> &def_mb,
        const char *str, const std::string &option_name = "mb");

// Attribute parsers. Unlike most parsers here, these take no default-vector
// argument.
bool parse_attr_oscale(std::vector<attr_t::scale_t> &oscale, const char *str,
        const std::string &option_name = "attr-oscale");

bool parse_attr_post_ops(std::vector<attr_t::post_ops_t> &po, const char *str,
        const std::string &option_name = "attr-post-ops");

bool parse_attr_scales(std::vector<attr_t::arg_scales_t> &scales,
        const char *str, const std::string &option_name = "attr-scales");

bool parse_attr_zero_points(std::vector<attr_t::zero_points_t> &zp,
        const char *str, const std::string &option_name = "attr-zero-points");

bool parse_attr_scratchpad_mode(
        std::vector<dnnl_scratchpad_mode_t> &scratchpad_mode,
        const std::vector<dnnl_scratchpad_mode_t> &def_scratchpad_mode,
        const char *str, const std::string &option_name = "attr-scratchpad");

bool parse_attr_fpmath_mode(std::vector<dnnl_fpmath_mode_t> &fpmath_mode,
        const std::vector<dnnl_fpmath_mode_t> &def_fpmath_mode, const char *str,
        const std::string &option_name = "attr-fpmath");

// Thread-context parsers. Unlike the other parsers, these take no
// `option_name` parameter; the option names are fixed in the implementation.
bool parse_ctx_init(std::vector<thr_ctx_t> &ctx,
        const std::vector<thr_ctx_t> &def_ctx, const char *str);
bool parse_ctx_exe(std::vector<thr_ctx_t> &ctx,
        const std::vector<thr_ctx_t> &def_ctx, const char *str);

bool parse_axis(std::vector<int> &axis, const std::vector<int> &def_axis,
        const char *str, const std::string &option_name = "axis");

// Stores the matched pattern through the `const char *&` out-parameter
// rather than a vector.
bool parse_test_pattern_match(const char *&match, const char *str,
        const std::string &option_name = "match");

bool parse_inplace(std::vector<bool> &inplace,
        const std::vector<bool> &def_inplace, const char *str,
        const std::string &option_name = "inplace");

bool parse_skip_nonlinear(std::vector<bool> &skip,
        const std::vector<bool> &def_skip, const char *str,
        const std::string &option_name = "skip-nonlinear");

bool parse_strides(std::vector<vdims_t> &strides,
        const std::vector<vdims_t> &def_strides, const char *str,
        const std::string &option_name = "strides");

bool parse_trivial_strides(std::vector<bool> &ts,
        const std::vector<bool> &def_ts, const char *str,
        const std::string &option_name = "trivial-strides");

bool parse_scale_policy(std::vector<policy_t> &policy,
        const std::vector<policy_t> &def_policy, const char *str,
        const std::string &option_name = "scaling");
232 | |
// plain types
// Parsers for single (non-vector) values and actions; implemented outside
// this header.

// Presumably selects between @p pt_def and @p pt_csv (or a user template in
// @p str) and stores the choice in @p pt — confirm in the implementation.
bool parse_perf_template(const char *&pt, const char *pt_def,
        const char *pt_csv, const char *str,
        const std::string &option_name = "perf-template");

// Takes the driver's `bench_f` entry point; presumably runs it on the batch
// file named by the option — confirm in the implementation.
bool parse_batch(const bench_f bench, const char *str,
        const std::string &option_name = "batch");

// Both default to the option name "help"; `parse_main_help` is presumably the
// variant used before a driver is selected — TODO confirm.
bool parse_help(const char *str, const std::string &option_name = "help");
bool parse_main_help(const char *str, const std::string &option_name = "help");
243 | |
// prb_dims_t type
// `prb_vdims_t` is meant to describe two or more tensors. In rare cases, such
// as concat, the library accepts a single input. To allow that, the driver
// must pass the minimum number of inputs it accepts via `min_inputs` (default
// 2).
void parse_prb_vdims(
        prb_vdims_t &prb_vdims, const std::string &str, size_t min_inputs = 2);
// Single-tensor counterpart of `parse_prb_vdims`; fills @p prb_dims from
// @p str.
void parse_prb_dims(prb_dims_t &prb_dims, const std::string &str);
252 | |
// service functions
// Implemented outside this header; behavior below is described only where
// this header documents it.
bool parse_bench_settings(const char *str);

void catch_unknown_options(const char *str);

int parse_last_argument();

// Returns the substring of @p `s` that starts at index @p `start_pos` and
// ends right before the next @p `delim` (or at the end of `s` if there is
// none). Advances @p `start_pos` to the index just after that `delim`, or
// sets it to `npos` when no `delim` was found.
// E.g. 1) s=apple:juice, start_pos=0, delim=':'
//         get_substr -> apple && start_pos -> 6
//      2) s=apple:juice, start_pos=6, delim=':'
//         get_substr -> juice && start_pos -> npos
//      3) s=apple:juice, start_pos=0, delim=';'
//         get_substr -> apple:juice && start_pos -> npos
std::string get_substr(const std::string &s, size_t &start_pos, char delim);
270 | |
271 | } // namespace parser |
272 | |
273 | #endif |
274 | |