1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef UTILS_PARSER_HPP
18#define UTILS_PARSER_HPP
19
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sstream>
#include <string>
#include <vector>

#include "oneapi/dnnl/dnnl_types.h"

#include "dnn_types.hpp"
#include "dnnl_debug.hpp"
#include "tests/test_thread.hpp"
#include "utils/dims.hpp"
33
34namespace parser {
35
// Parser-wide state; all three objects are defined out of line.
//
// NOTE(review): judging by the name, this flags whether the last parsed
// argument was a problem descriptor rather than an option -- confirm against
// its definition.
extern bool last_parsed_is_problem;
// "Not found" sentinel that the helpers in this header compare std::string
// search results against (so it is expected to equal std::string::npos).
extern const size_t eol;
// Help text buffer; presumably appended to by add_option_to_help -- confirm.
extern std::stringstream help_ss;

namespace parser_utils {

// Returns the pattern a command-line argument must start with to be treated
// as option @p option_name. The option helpers strip this prefix and parse
// what follows. Flag-style options (e.g. parse_reset) pass
// @p with_args = false.
std::string get_pattern(
        const std::string &option_name, bool with_args = true);

// Registers @p option together with @p help_message for the help output.
void add_option_to_help(const std::string &option,
        const std::string &help_message, bool with_args = true);

} // namespace parser_utils
45
// Splits @p str on @p delimiter and appends process_func(token) to @p vec for
// every token, in order. Each token is handed to @p process_func as a C string.
//
// - An empty @p str does not call @p process_func at all; @p vec is reset to
//   @p def instead.
// - Empty tokens (the middle one in "a,,b", the trailing one in "a,") are
//   still forwarded to @p process_func as "".
//
// Always returns true; the bool return mirrors the other parse_* helpers.
template <typename T, typename F>
static bool parse_vector_str(T &vec, const T &def, F process_func,
        const std::string &str, char delimiter = ',') {
    if (str.empty()) return vec = def, true;

    vec.clear();
    size_t pos_st = 0;
    while (true) {
        // find_first_of reports "no more delimiters" with std::string::npos.
        const size_t pos_en = str.find_first_of(delimiter, pos_st);
        // substr clamps the count, so (npos - pos_st) means "to the end".
        vec.push_back(
                process_func(str.substr(pos_st, pos_en - pos_st).c_str()));
        if (pos_en == std::string::npos) break;
        pos_st = pos_en + 1;
    }
    return true;
}
60
// Parses a two-level list into @p vec.
//
// @p str is split on @p vector_delim into words. Each word is then split on
// @p element_delim into the elements of one T, each element produced by
// @p process_func (both levels go through parse_vector_str). With the
// default delimiters, "a:b,c" becomes {{f(a), f(b)}, {f(c)}}.
//
// An empty @p str yields @p def. An empty word yields a default-constructed T.
// Returns true, as parse_vector_str does.
template <typename T, typename F>
static bool parse_multivector_str(std::vector<T> &vec,
        const std::vector<T> &def, F process_func, const std::string &str,
        char vector_delim = ',', char element_delim = ':') {
    // Inner level: one word -> one T. No caller-provided default exists at
    // this level, so an empty word falls back to a default-constructed T.
    auto word_to_elements = [&](const char *word) {
        T elements;
        T no_default;
        parse_vector_str(
                elements, no_default, process_func, word, element_delim);
        return elements;
    };

    // Outer level: split the whole string into words.
    return parse_vector_str(vec, def, word_to_elements, str, vector_delim);
}
75
76template <typename T, typename F>
77static bool parse_vector_option(T &vec, const T &def, F process_func,
78 const char *str, const std::string &option_name,
79 const std::string &help_message = "") {
80 parser_utils::add_option_to_help(option_name, help_message);
81 const std::string pattern = parser_utils::get_pattern(option_name);
82 if (pattern.find(str, 0, pattern.size()) == eol) return false;
83 return parse_vector_str(vec, def, process_func, str + pattern.size());
84}
85
86template <typename T, typename F>
87static bool parse_multivector_option(std::vector<T> &vec,
88 const std::vector<T> &def, F process_func, const char *str,
89 const std::string &option_name, const std::string &help_message = "",
90 char vector_delim = ',', char element_delim = ':') {
91 parser_utils::add_option_to_help(option_name, help_message);
92 const std::string pattern = parser_utils::get_pattern(option_name);
93 if (pattern.find(str, 0, pattern.size()) == eol) return false;
94 return parse_multivector_str(vec, def, process_func, str + pattern.size(),
95 vector_delim, element_delim);
96}
97
98template <typename T, typename F>
99static bool parse_single_value_option(T &val, const T &def_val, F process_func,
100 const char *str, const std::string &option_name,
101 const std::string &help_message = "") {
102 parser_utils::add_option_to_help(option_name, help_message);
103 const std::string pattern = parser_utils::get_pattern(option_name);
104 if (pattern.find(str, 0, pattern.size()) == eol) return false;
105 str = str + pattern.size();
106 if (*str == '\0') return val = def_val, true;
107 return val = process_func(str), true;
108}
109
// Handles the `cfg` option (or @p option_name). It registers the CFG help
// text and parses a ','-separated list of configurations into @p vec, each
// entry converted by @p process_func (see parse_vector_option).
// Returns false when @p str is not this option.
template <typename T, typename F>
static bool parse_cfg(T &vec, const T &def, F process_func, const char *str,
        const std::string &option_name = "cfg") {
    static const std::string cfg_help
            = "CFG (Default: `f32`)\n Specifies data types `CFG` for "
              "source, weights (if supported) and destination of operation.\n "
              " `CFG` values vary from driver to driver.\n";
    const bool is_cfg_option = parse_vector_option(
            vec, def, process_func, str, option_name, cfg_help);
    return is_cfg_option;
}
119
// Handles the `alg` option (or @p option_name). It registers the ALG help
// text and parses a ','-separated list of algorithms into @p vec, each entry
// converted by @p process_func (see parse_vector_option).
// Returns false when @p str is not this option.
template <typename T, typename F>
static bool parse_alg(T &vec, const T &def, F process_func, const char *str,
        const std::string &option_name = "alg") {
    static const std::string alg_help
            = "ALG (Default: depends on driver)\n Specifies operation "
              "algorithm `ALG`.\n `ALG` values vary from driver to "
              "driver.\n";
    const bool is_alg_option = parse_vector_option(
            vec, def, process_func, str, option_name, alg_help);
    return is_alg_option;
}
129
// Parses a vector-valued attribute option named @p option_name into @p vec.
//
// Each ','-separated entry is turned into a T by calling T::from_str on it;
// the returned status is checked with the project's SAFE_V macro. The default
// (used when no value follows the option pattern) is one default-constructed
// T. Returns false when @p str is not this option.
template <typename T>
bool parse_subattr(std::vector<T> &vec, const char *str,
        const std::string &option_name, const std::string &help_message = "") {
    const auto from_string = [](const std::string &entry) {
        T parsed;
        SAFE_V(parsed.from_str(entry));
        return parsed;
    };
    const std::vector<T> single_default {T()};
    return parse_vector_option(
            vec, single_default, from_string, str, option_name, help_message);
}
142
143template <typename S>
144bool parse_reset(S &settings, const char *str,
145 const std::string &option_name = "reset") {
146 static const std::string help
147 = "\n Instructs the driver to reset driver specific options to "
148 "their default values.\n Neither global options nor "
149 "`--perf-template` option would be reset.";
150 parser_utils::add_option_to_help(option_name, help, false);
151
152 const std::string pattern = parser_utils::get_pattern(option_name, false);
153 if (pattern.find(str, 0, pattern.size()) == eol) return false;
154 settings.reset();
155 return true;
156}
157
158// vector types
159bool parse_dir(std::vector<dir_t> &dir, const std::vector<dir_t> &def_dir,
160 const char *str, const std::string &option_name = "dir");
161
162bool parse_dt(std::vector<dnnl_data_type_t> &dt,
163 const std::vector<dnnl_data_type_t> &def_dt, const char *str,
164 const std::string &option_name = "dt");
165
166bool parse_multi_dt(std::vector<std::vector<dnnl_data_type_t>> &dt,
167 const std::vector<std::vector<dnnl_data_type_t>> &def_dt,
168 const char *str, const std::string &option_name = "sdt");
169
170bool parse_tag(std::vector<std::string> &tag,
171 const std::vector<std::string> &def_tag, const char *str,
172 const std::string &option_name = "tag");
173
174bool parse_multi_tag(std::vector<std::vector<std::string>> &tag,
175 const std::vector<std::vector<std::string>> &def_tag, const char *str,
176 const std::string &option_name = "stag");
177
178bool parse_mb(std::vector<int64_t> &mb, const std::vector<int64_t> &def_mb,
179 const char *str, const std::string &option_name = "mb");
180
181bool parse_attr_oscale(std::vector<attr_t::scale_t> &oscale, const char *str,
182 const std::string &option_name = "attr-oscale");
183
184bool parse_attr_post_ops(std::vector<attr_t::post_ops_t> &po, const char *str,
185 const std::string &option_name = "attr-post-ops");
186
187bool parse_attr_scales(std::vector<attr_t::arg_scales_t> &scales,
188 const char *str, const std::string &option_name = "attr-scales");
189
190bool parse_attr_zero_points(std::vector<attr_t::zero_points_t> &zp,
191 const char *str, const std::string &option_name = "attr-zero-points");
192
193bool parse_attr_scratchpad_mode(
194 std::vector<dnnl_scratchpad_mode_t> &scratchpad_mode,
195 const std::vector<dnnl_scratchpad_mode_t> &def_scratchpad_mode,
196 const char *str, const std::string &option_name = "attr-scratchpad");
197
198bool parse_attr_fpmath_mode(std::vector<dnnl_fpmath_mode_t> &fpmath_mode,
199 const std::vector<dnnl_fpmath_mode_t> &def_fpmath_mode, const char *str,
200 const std::string &option_name = "attr-fpmath");
201
202bool parse_ctx_init(std::vector<thr_ctx_t> &ctx,
203 const std::vector<thr_ctx_t> &def_ctx, const char *str);
204bool parse_ctx_exe(std::vector<thr_ctx_t> &ctx,
205 const std::vector<thr_ctx_t> &def_ctx, const char *str);
206
207bool parse_axis(std::vector<int> &axis, const std::vector<int> &def_axis,
208 const char *str, const std::string &option_name = "axis");
209
210bool parse_test_pattern_match(const char *&match, const char *str,
211 const std::string &option_name = "match");
212
213bool parse_inplace(std::vector<bool> &inplace,
214 const std::vector<bool> &def_inplace, const char *str,
215 const std::string &option_name = "inplace");
216
217bool parse_skip_nonlinear(std::vector<bool> &skip,
218 const std::vector<bool> &def_skip, const char *str,
219 const std::string &option_name = "skip-nonlinear");
220
221bool parse_strides(std::vector<vdims_t> &strides,
222 const std::vector<vdims_t> &def_strides, const char *str,
223 const std::string &option_name = "strides");
224
225bool parse_trivial_strides(std::vector<bool> &ts,
226 const std::vector<bool> &def_ts, const char *str,
227 const std::string &option_name = "trivial-strides");
228
229bool parse_scale_policy(std::vector<policy_t> &policy,
230 const std::vector<policy_t> &def_policy, const char *str,
231 const std::string &option_name = "scaling");
232
233// plain types
234bool parse_perf_template(const char *&pt, const char *pt_def,
235 const char *pt_csv, const char *str,
236 const std::string &option_name = "perf-template");
237
238bool parse_batch(const bench_f bench, const char *str,
239 const std::string &option_name = "batch");
240
241bool parse_help(const char *str, const std::string &option_name = "help");
242bool parse_main_help(const char *str, const std::string &option_name = "help");
243
244// prb_dims_t type
245// `prb_vdims_t` type is supposed to run on 2+ tensors. However, in rare cases
246// like concat, the library allows a single input. To run a single input, it's
247// now a user's responsibility to define a minimum number of inputs for the
248// driver with `min_inputs` parameter.
249void parse_prb_vdims(
250 prb_vdims_t &prb_vdims, const std::string &str, size_t min_inputs = 2);
251void parse_prb_dims(prb_dims_t &prb_dims, const std::string &str);
252
// ---------------------------------------------------------------------------
// Service functions (definitions are out of line).
// ---------------------------------------------------------------------------
bool parse_bench_settings(const char *str);

void catch_unknown_options(const char *str);

int parse_last_argument();

// Returns the part of @p s that starts at @p start_pos and ends just before
// the next @p delim, or at the end of @p s if no @p delim follows.
// On return, @p start_pos points one past that delimiter, or is npos if no
// delimiter was found. Examples with s = "apple:juice":
//   start_pos = 0, delim = ':'  ->  "apple",        start_pos becomes 6
//   start_pos = 6, delim = ':'  ->  "juice",        start_pos becomes npos
//   start_pos = 0, delim = ';'  ->  "apple:juice",  start_pos becomes npos
std::string get_substr(const std::string &s, size_t &start_pos, char delim);
270
271} // namespace parser
272
273#endif
274