1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
#include <assert.h>
#include <cctype>
#include <cmath>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>
#include <string>

#include <algorithm>
#include <iostream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <vector>

#include "oneapi/dnnl/dnnl.h"

#include "src/common/math_utils.hpp"

#include "common.hpp"
#include "dnn_types.hpp"
#include "dnnl_common.hpp"
#include "dnnl_debug.hpp"
#include "dnnl_memory.hpp"
#include "utils/parser.hpp"
39 | |
40 | #define BENCHDNN_DNNL_ARG_UNDEF 0 |
41 | |
namespace tag {
// Spellings of the memory-format tags that are referenced by name elsewhere.
const char *x = "x";
const char *abx = "abx";
const char *axb = "axb";
const char *any = "any";
const char *undef = "undef";
} // namespace tag
49 | |
50 | std::ostream &operator<<(std::ostream &s, dir_t dir) { |
51 | #define CASE(x) \ |
52 | if (dir == (x)) return s << STRINGIFY(x) |
53 | CASE(FWD_B); |
54 | CASE(FWD_D); |
55 | CASE(FWD_I); |
56 | CASE(BWD_D); |
57 | CASE(BWD_DW); |
58 | CASE(BWD_W); |
59 | CASE(BWD_WB); |
60 | #undef CASE |
61 | SAFE_V(FAIL); |
62 | return s; |
63 | } |
64 | |
65 | std::ostream &operator<<(std::ostream &s, dnnl_data_type_t dt) { |
66 | s << dt2str(dt); |
67 | return s; |
68 | } |
69 | |
70 | std::ostream &operator<<(std::ostream &s, dnnl_engine_kind_t ek) { |
71 | s << engine_kind2str(ek); |
72 | return s; |
73 | } |
74 | |
75 | dir_t str2dir(const char *str) { |
76 | #define CASE(x) \ |
77 | if (!strcasecmp(STRINGIFY(x), str)) return x |
78 | CASE(FWD_D); |
79 | CASE(FWD_I); |
80 | CASE(FWD_B); |
81 | CASE(BWD_D); |
82 | CASE(BWD_W); |
83 | CASE(BWD_WB); |
84 | CASE(BWD_DW); |
85 | #undef CASE |
86 | assert(!"unknown dir" ); |
87 | return DIR_UNDEF; |
88 | } |
89 | |
90 | dnnl_prop_kind_t prop2prop_kind(const dir_t dir) { |
91 | if (dir == FWD_D) return dnnl_forward_training; |
92 | if (dir == FWD_I) return dnnl_forward_inference; |
93 | if (dir == BWD_DW) return dnnl_backward; |
94 | assert(!"unknown dir" ); |
95 | return dnnl_prop_kind_undef; |
96 | } |
97 | |
98 | const char *prop2str(dnnl_prop_kind_t prop) { |
99 | if (prop == dnnl_forward_training) return "FWD_D" ; |
100 | if (prop == dnnl_forward_inference) return "FWD_I" ; |
101 | if (prop == dnnl_backward) return "BWD_DW" ; |
102 | assert(!"unknown prop_kind" ); |
103 | return "unknown prop_kind" ; |
104 | } |
105 | |
106 | const char *data_kind2str(data_kind_t kind) { |
107 | switch (kind) { |
108 | case SRC: return "SRC" ; |
109 | case SRC_1: return "SRC_ADD" ; |
110 | case WEI: return "WEI" ; |
111 | case BIA: return "BIA" ; |
112 | case DST: return "DST" ; |
113 | case ACC: return "ACC" ; |
114 | case MEAN: return "MEAN" ; |
115 | case VAR: return "VAR" ; |
116 | case SC: return "SC" ; |
117 | case SH: return "SH" ; |
118 | case DST_ITER: return "DST_ITER" ; |
119 | case DST_ITER_C: return "DST_ITER_C" ; |
120 | case AUGRU_ATTENTION: return "AUGRU_ATTENTION" ; |
121 | case SRC_ITER: return "SRC_ITER" ; |
122 | case SRC_ITER_C: return "SRC_ITER_C" ; |
123 | case WEI_ITER: return "WEI_ITER" ; |
124 | case WEI_PEEPHOLE: return "WEI_PEEPHOLE" ; |
125 | case WEI_PROJECTION: return "WEI_PROJECTION" ; |
126 | default: assert(!"incorrect data kind" ); |
127 | } |
128 | return "incorrect data kind" ; |
129 | } |
130 | |
// Argument names accepted on the command line for per-argument attributes
// (scales, zero points). The first spelling of each entry is the canonical
// one printed by arg2str(); str2arg() accepts any listed spelling.
static const std::map<int, std::vector<const char *>> supported_args {
        {DNNL_ARG_SRC, {"src", "src0"}},
        {DNNL_ARG_SRC_1, {"src1"}},
        {DNNL_ARG_WEIGHTS, {"wei"}},
        {DNNL_ARG_DST, {"dst"}},
};
137 | |
138 | static int str2arg(const std::string &str) { |
139 | for (const auto &arg : supported_args) |
140 | for (const auto &s : arg.second) |
141 | if (str.compare(s) == 0) return arg.first; |
142 | // multiple srcs |
143 | if (str.compare(0, 3, "msrc" )) { |
144 | const auto &str_index = str.substr(4); |
145 | const auto index = stoul(str_index); |
146 | return DNNL_ARG_MULTIPLE_SRC + index; |
147 | } |
148 | return BENCHDNN_DNNL_ARG_UNDEF; |
149 | } |
150 | |
151 | static std::string arg2str(int arg) { |
152 | if (supported_args.find(arg) != supported_args.end()) |
153 | return std::string(supported_args.at(arg)[0]); |
154 | if (arg & DNNL_ARG_MULTIPLE_SRC) { |
155 | std::string msrc("msrc" ); |
156 | const int index = arg - DNNL_ARG_MULTIPLE_SRC; |
157 | return msrc + std::to_string(index); |
158 | } |
159 | assert(!"unknown argument" ); |
160 | return "unknown argument" ; |
161 | } |
162 | |
163 | policy_t attr_t::str2policy(const std::string &str) { |
164 | std::string s(str); |
165 | // s.compare is lexicographical, case matters |
166 | std::transform(s.begin(), s.end(), s.begin(), ::toupper); |
167 | #define CASE(_plc) \ |
168 | if (s.compare(STRINGIFY(_plc)) == 0) return _plc |
169 | CASE(COMMON); |
170 | CASE(PER_OC); |
171 | CASE(PER_DIM_0); |
172 | CASE(PER_DIM_1); |
173 | CASE(PER_DIM_01); |
174 | CASE(PER_DIM_2); |
175 | CASE(PER_DIM_023); |
176 | CASE(PER_DIM_23); |
177 | CASE(PER_DIM_03); |
178 | CASE(PER_DIM_3); |
179 | CASE(PER_TENSOR); |
180 | #undef CASE |
181 | assert(!"unknown attr_t::policy_t policy" ); |
182 | return POLICY_TOTAL; |
183 | } |
184 | |
185 | const char *attr_t::policy2str(policy_t policy) { |
186 | if (policy == COMMON) return "common" ; |
187 | if (policy == PER_OC) return "per_oc" ; |
188 | if (policy == PER_DIM_0) return "per_dim_0" ; |
189 | if (policy == PER_DIM_1) return "per_dim_1" ; |
190 | if (policy == PER_DIM_01) return "per_dim_01" ; |
191 | if (policy == PER_DIM_2) return "per_dim_2" ; |
192 | if (policy == PER_DIM_023) return "per_dim_023" ; |
193 | if (policy == PER_DIM_23) return "per_dim_23" ; |
194 | if (policy == PER_DIM_03) return "per_dim_03" ; |
195 | if (policy == PER_DIM_3) return "per_dim_3" ; |
196 | if (policy == PER_TENSOR) return "per_tensor" ; |
197 | assert(!"unknown attr_t::policy_t policy" ); |
198 | return "unknown attr_t::policy_t policy" ; |
199 | } |
200 | |
201 | int attr_t::get_default_mask(policy_t policy, int arg) { |
202 | switch (policy) { |
203 | case PER_DIM_0: return (1 << 0); |
204 | case PER_OC: |
205 | if (arg == DNNL_ARG_WEIGHTS) return get_default_mask(PER_DIM_0); |
206 | |
207 | case PER_DIM_1: return (1 << 1); |
208 | case PER_DIM_01: return (1 << 0) + (1 << 1); |
209 | case PER_DIM_2: return (1 << 2); |
210 | case PER_DIM_023: return (1 << 0) + (1 << 2) + (1 << 3); |
211 | case PER_DIM_23: return (1 << 2) + (1 << 3); |
212 | case PER_DIM_03: return (1 << 0) + (1 << 3); |
213 | case PER_DIM_3: return (1 << 3); |
214 | case PER_TENSOR: return (1 << DNNL_MAX_NDIMS) - 1; |
215 | case COMMON: return 0; |
216 | default: SAFE(FAIL, CRIT); return 0; |
217 | } |
218 | } |
219 | |
220 | // This function takes input string, extracts float value and runtime, if |
221 | // present, from the string. Updates @value and @runtime with extracted values. |
222 | int parse_value_and_runtime(float &value, bool &runtime, const std::string &s) { |
223 | // process value |
224 | size_t scale_pos = 0; |
225 | try { |
226 | value = std::stof(s, &scale_pos); |
227 | } catch (const std::invalid_argument &) { |
228 | BENCHDNN_PRINT(0, "%s\n%s \'%s\'; %s\n" , |
229 | "Error: output scale or zero point input value is invalid." , |
230 | "Given input:" , s.c_str(), |
231 | "Expected input: \'VAL[*]\'. See help for proper syntax." ); |
232 | exit(1); |
233 | } |
234 | runtime = false; |
235 | if (scale_pos + 1 < s.size()) return FAIL; |
236 | if (scale_pos == s.size()) return OK; |
237 | if (s.back() != '*') return FAIL; |
238 | runtime = true; |
239 | return OK; |
240 | } |
241 | |
242 | int attr_t::scale_t::from_str(const std::string &s) { |
243 | *this = scale_t(); |
244 | if (s.empty()) return OK; |
245 | |
246 | size_t start_pos = 0; |
247 | // process policy |
248 | this->policy = str2policy(parser::get_substr(s, start_pos, ':')); |
249 | if (this->policy == POLICY_TOTAL) return FAIL; |
250 | if (start_pos == std::string::npos) return OK; |
251 | if (start_pos >= s.size()) return FAIL; // to catch dangling ':' |
252 | |
253 | SAFE(parse_value_and_runtime(this->scale, this->runtime, |
254 | parser::get_substr(s, start_pos, ':')), |
255 | WARN); |
256 | if (this->scale < 0) return FAIL; |
257 | return OK; |
258 | } |
259 | |
260 | int attr_t::zero_points_t::from_str(const std::string &s) { |
261 | *this = zero_points_t(); |
262 | if (s.empty()) return OK; |
263 | |
264 | size_t start_pos = 0; |
265 | while (start_pos != std::string::npos) { |
266 | auto subs = parser::get_substr(s, start_pos, '+'); |
267 | size_t subs_pos = 0; |
268 | |
269 | auto arg = str2arg(parser::get_substr(subs, subs_pos, ':')); |
270 | if (arg == BENCHDNN_DNNL_ARG_UNDEF || subs_pos == std::string::npos |
271 | || subs_pos >= subs.size()) |
272 | return FAIL; |
273 | |
274 | auto policy = str2policy(parser::get_substr(subs, subs_pos, ':')); |
275 | if (policy == POLICY_TOTAL || subs_pos == std::string::npos |
276 | || subs_pos >= subs.size()) |
277 | return FAIL; |
278 | |
279 | float zp = 0; |
280 | bool runtime = false; |
281 | SAFE(parse_value_and_runtime( |
282 | zp, runtime, parser::get_substr(subs, subs_pos, '\0')), |
283 | WARN); |
284 | set(arg, policy, static_cast<int>(zp), runtime); |
285 | } |
286 | return OK; |
287 | } |
288 | |
289 | int attr_t::arg_scales_t::from_str(const std::string &s) { |
290 | *this = arg_scales_t(); |
291 | if (s.empty()) return OK; |
292 | |
293 | size_t start_pos = 0; |
294 | while (start_pos != std::string::npos) { |
295 | auto subs = parser::get_substr(s, start_pos, '+'); |
296 | // Special handling for really big float values |
297 | if (subs.back() == 'e') { |
298 | auto subs_add = parser::get_substr(s, start_pos, '+'); |
299 | subs += subs_add; |
300 | } |
301 | size_t subs_pos = 0; |
302 | |
303 | auto arg = str2arg(parser::get_substr(subs, subs_pos, ':')); |
304 | if (arg == BENCHDNN_DNNL_ARG_UNDEF || subs_pos == std::string::npos |
305 | || subs_pos >= s.size()) |
306 | return FAIL; |
307 | |
308 | scale_t arg_scale; |
309 | SAFE(arg_scale.from_str(parser::get_substr(subs, subs_pos, '\0')), |
310 | WARN); |
311 | set(arg, arg_scale); |
312 | } |
313 | return OK; |
314 | } |
315 | |
// Short alias for the post-op kind enumeration used throughout this file.
using pk_t = attr_t::post_ops_t::kind_t;

// One row of the post-op lookup table:
//   kind       - the benchdnn post-op kind;
//   kind_names - accepted (lower-case) spellings; the first one is the
//                canonical name returned by kind2str();
//   dnnl_kind  - matching oneDNN algorithm kind, or dnnl_alg_kind_undef
//                where there is none.
struct po_table_entry_t {
    pk_t kind;
    std::vector<std::string> kind_names;
    dnnl_alg_kind_t dnnl_kind;
};

// Lookup table shared by str2kind(), kind2str() and kind2dnnl_kind().
// The *_START / *_END rows correspond to the range markers that
// is_eltwise_kind() / is_binary_kind() compare against. The final
// KIND_TOTAL row is the guard returned when a lookup fails, so it must
// remain the last element.
static po_table_entry_t kind_table[] = {
        // sum
        {pk_t::SUM, {"sum"}, dnnl_alg_kind_undef},
        // depthwise convolution
        {pk_t::DW, {"dw"}, dnnl_convolution_auto},
        {pk_t::DW_K3S1P1, {"dw_k3s1p1"}, dnnl_convolution_auto},
        {pk_t::DW_K3S2P1, {"dw_k3s2p1"}, dnnl_convolution_auto},
        // eltwise
        {pk_t::ELTWISE_START, {"eltwise_undef"}, dnnl_alg_kind_undef},
        {pk_t::ABS, {"abs", "eltwise_abs"}, dnnl_eltwise_abs},
        {pk_t::CLIP, {"clip", "eltwise_clip"}, dnnl_eltwise_clip},
        {pk_t::CLIP_V2, {"clip_v2", "eltwise_clip_v2"}, dnnl_eltwise_clip_v2},
        {pk_t::CLIP_V2_DST, {"clip_v2_dst", "eltwise_clip_v2_use_dst_for_bwd"},
                dnnl_eltwise_clip_v2_use_dst_for_bwd},
        {pk_t::ELU, {"elu", "eltwise_elu"}, dnnl_eltwise_elu},
        {pk_t::ELU_DST, {"elu_dst", "eltwise_elu_use_dst_for_bwd"},
                dnnl_eltwise_elu_use_dst_for_bwd},
        {pk_t::EXP, {"exp", "eltwise_exp"}, dnnl_eltwise_exp},
        {pk_t::EXP_DST, {"exp_dst", "eltwise_exp_use_dst_for_bwd"},
                dnnl_eltwise_exp_use_dst_for_bwd},
        {pk_t::GELU_ERF, {"gelu_erf", "eltwise_gelu_erf"},
                dnnl_eltwise_gelu_erf},
        {pk_t::GELU_TANH, {"gelu_tanh", "eltwise_gelu_tanh"},
                dnnl_eltwise_gelu_tanh},
        {pk_t::HARDSIGMOID, {"hardsigmoid", "eltwise_hardsigmoid"},
                dnnl_eltwise_hardsigmoid},
        {pk_t::HARDSWISH, {"hardswish", "eltwise_hardswish"},
                dnnl_eltwise_hardswish},
        {pk_t::LINEAR, {"linear", "eltwise_linear"}, dnnl_eltwise_linear},
        {pk_t::LOG, {"log", "eltwise_log"}, dnnl_eltwise_log},
        {pk_t::LOGISTIC, {"logistic", "eltwise_logistic"},
                dnnl_eltwise_logistic},
        {pk_t::LOGISTIC_DST,
                {"logistic_dst", "eltwise_logistic_use_dst_for_bwd"},
                dnnl_eltwise_logistic_use_dst_for_bwd},
        {pk_t::MISH, {"mish", "eltwise_mish"}, dnnl_eltwise_mish},
        {pk_t::POW, {"pow", "eltwise_pow"}, dnnl_eltwise_pow},
        {pk_t::RELU, {"relu", "eltwise_relu"}, dnnl_eltwise_relu},
        {pk_t::RELU_DST, {"relu_dst", "eltwise_relu_use_dst_for_bwd"},
                dnnl_eltwise_relu_use_dst_for_bwd},
        {pk_t::ROUND, {"round", "eltwise_round"}, dnnl_eltwise_round},
        {pk_t::SQRT, {"sqrt", "eltwise_sqrt"}, dnnl_eltwise_sqrt},
        {pk_t::SQRT_DST, {"sqrt_dst", "eltwise_sqrt_use_dst_for_bwd"},
                dnnl_eltwise_sqrt_use_dst_for_bwd},
        {pk_t::SQUARE, {"square", "eltwise_square"}, dnnl_eltwise_square},
        {pk_t::SRELU, {"soft_relu", "eltwise_soft_relu", "srelu"},
                dnnl_eltwise_soft_relu},
        {pk_t::SWISH, {"swish", "eltwise_swish"}, dnnl_eltwise_swish},
        {pk_t::TANH, {"tanh", "eltwise_tanh"}, dnnl_eltwise_tanh},
        {pk_t::TANH_DST, {"tanh_dst", "eltwise_tanh_use_dst_for_bwd"},
                dnnl_eltwise_tanh_use_dst_for_bwd},
        {pk_t::ELTWISE_END, {"eltwise_undef"}, dnnl_alg_kind_undef},
        // binary
        {pk_t::BINARY_START, {"binary_undef"}, dnnl_alg_kind_undef},
        {pk_t::ADD, {"add", "binary_add"}, dnnl_binary_add},
        {pk_t::DIV, {"div", "binary_div"}, dnnl_binary_div},
        {pk_t::EQ, {"eq", "binary_eq"}, dnnl_binary_eq},
        {pk_t::GE, {"ge", "binary_ge"}, dnnl_binary_ge},
        {pk_t::GT, {"gt", "binary_gt"}, dnnl_binary_gt},
        {pk_t::LE, {"le", "binary_le"}, dnnl_binary_le},
        {pk_t::LT, {"lt", "binary_lt"}, dnnl_binary_lt},
        {pk_t::MAX, {"max", "binary_max"}, dnnl_binary_max},
        {pk_t::MIN, {"min", "binary_min"}, dnnl_binary_min},
        {pk_t::MUL, {"mul", "binary_mul"}, dnnl_binary_mul},
        {pk_t::NE, {"ne", "binary_ne"}, dnnl_binary_ne},
        {pk_t::SUB, {"sub", "binary_sub"}, dnnl_binary_sub},
        {pk_t::BINARY_END, {"binary_undef"}, dnnl_alg_kind_undef},
        // prelu
        {pk_t::PRELU, {"prelu"}, dnnl_alg_kind_undef},
        // guard entry
        {pk_t::KIND_TOTAL, {"kind_undef"}, dnnl_alg_kind_undef}};
395 | |
396 | pk_t attr_t::post_ops_t::str2kind(const std::string &str) { |
397 | std::string s(str); |
398 | // string::operator== is lexicographical, case matters |
399 | std::transform(s.begin(), s.end(), s.begin(), ::tolower); |
400 | for (const auto &e : kind_table) { |
401 | for (const auto &name : e.kind_names) { |
402 | if (s == name) return e.kind; |
403 | } |
404 | } |
405 | BENCHDNN_PRINT(0, "%s\'%s\' %s\n" , "Error: " , str.c_str(), |
406 | "kind of post operation entry was not recognized." ); |
407 | |
408 | const auto table_size = sizeof(kind_table) / sizeof(*kind_table); |
409 | return kind_table[table_size - 1].kind; |
410 | } |
411 | |
412 | const char *attr_t::post_ops_t::kind2str(pk_t kind) { |
413 | for (const auto &e : kind_table) { |
414 | if (e.kind == kind) return e.kind_names[0].c_str(); |
415 | } |
416 | assert(!"unknown attr::post_ops::kind" ); |
417 | const auto table_size = sizeof(kind_table) / sizeof(*kind_table); |
418 | return kind_table[table_size - 1].kind_names[0].c_str(); |
419 | } |
420 | |
421 | dnnl_alg_kind_t attr_t::post_ops_t::kind2dnnl_kind(pk_t kind) { |
422 | for (const auto &e : kind_table) { |
423 | if (e.kind == kind) return e.dnnl_kind; |
424 | } |
425 | assert(!"unknown attr::post_ops::kind" ); |
426 | const auto table_size = sizeof(kind_table) / sizeof(*kind_table); |
427 | return kind_table[table_size - 1].dnnl_kind; |
428 | } |
429 | |
430 | std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks() const { |
431 | std::vector<std::pair<int, int>> v_masks; |
432 | for (int idx = 0; idx < len(); ++idx) { |
433 | const auto &e = this->entry[idx]; |
434 | policy_t policy = policy_t::COMMON; |
435 | int arg = BENCHDNN_DNNL_ARG_UNDEF; |
436 | if (e.is_binary_kind()) { |
437 | policy = e.binary.policy; |
438 | arg = DNNL_ARG_SRC_1; |
439 | } else if (e.is_prelu_kind()) { |
440 | policy = e.prelu.policy; |
441 | arg = DNNL_ARG_WEIGHTS; |
442 | } else |
443 | continue; |
444 | |
445 | const auto mask = attr_t::get_default_mask(policy); |
446 | v_masks.emplace_back(std::make_pair( |
447 | DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | arg, mask)); |
448 | } |
449 | return v_masks; |
450 | } |
451 | |
452 | int attr_t::post_ops_t::from_str(const std::string &s) { |
453 | *this = post_ops_t(); |
454 | if (s.empty()) return OK; |
455 | |
456 | size_t start_pos = 0; |
457 | while (start_pos != std::string::npos) { |
458 | auto subs = parser::get_substr(s, start_pos, '+'); |
459 | size_t subs_pos = 0; |
460 | |
461 | auto kind = str2kind(parser::get_substr(subs, subs_pos, ':')); |
462 | if (kind == KIND_TOTAL) return FAIL; |
463 | |
464 | entry.emplace_back(kind); |
465 | if (subs_pos == std::string::npos) continue; |
466 | if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':' |
467 | |
468 | auto &e = entry.back(); |
469 | if (e.is_sum_kind()) { |
470 | e.sum.scale = std::stof(parser::get_substr(subs, subs_pos, ':')); |
471 | if (subs_pos == std::string::npos) continue; |
472 | if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':' |
473 | |
474 | auto zp_str = parser::get_substr(subs, subs_pos, ':'); |
475 | e.sum.zero_point = std::stoi(zp_str); |
476 | if (subs_pos == std::string::npos) continue; |
477 | if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':' |
478 | if (std::to_string(e.sum.zero_point) != zp_str) return FAIL; |
479 | |
480 | e.sum.dt = str2dt(parser::get_substr(subs, subs_pos, ':').c_str()); |
481 | // sum dt, if specified, should be defined |
482 | if (e.sum.dt == dnnl_data_type_undef) return FAIL; |
483 | } else if (e.is_convolution_kind()) { |
484 | if (kind == DW) { |
485 | // `DW` has input of `dw:kXsYpZ`, while rest have `dw_k3sXp1`. |
486 | const auto str_dw_params |
487 | = parser::get_substr(subs, subs_pos, ':'); |
488 | size_t pos = 0, idx = 0; |
489 | |
490 | pos += idx; |
491 | if (str_dw_params[pos] != 'k') return FAIL; |
492 | e.convolution.kernel = std::stoi(&str_dw_params[++pos], &idx); |
493 | |
494 | pos += idx; |
495 | if (str_dw_params[pos] != 's') return FAIL; |
496 | e.convolution.stride = std::stoi(&str_dw_params[++pos], &idx); |
497 | |
498 | pos += idx; |
499 | if (str_dw_params[pos] != 'p') return FAIL; |
500 | e.convolution.padding = std::stoi(&str_dw_params[++pos]); |
501 | |
502 | if (subs_pos == std::string::npos) continue; |
503 | } |
504 | |
505 | e.convolution.dst_dt |
506 | = str2dt(parser::get_substr(subs, subs_pos, ':').c_str()); |
507 | if (e.convolution.dst_dt == dnnl_data_type_undef) return FAIL; |
508 | if (subs_pos == std::string::npos) continue; |
509 | if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':' |
510 | |
511 | auto scale_str = parser::get_substr(subs, subs_pos, '+'); |
512 | SAFE(e.convolution.wei_scale.from_str(scale_str), WARN); |
513 | size_t dst_scale_pos = 0; |
514 | for (int i = 0; i < 2; ++i) |
515 | dst_scale_pos = scale_str.find(":" , dst_scale_pos + 1); |
516 | if (dst_scale_pos != std::string::npos) { |
517 | auto dst_scale_str = scale_str.substr(dst_scale_pos + 1); |
518 | SAFE(e.convolution.dst_scale.from_str(dst_scale_str), WARN); |
519 | } |
520 | } else if (e.is_eltwise_kind()) { |
521 | e.eltwise.alpha |
522 | = std::stof(parser::get_substr(subs, subs_pos, ':')); |
523 | if (subs_pos == std::string::npos) continue; |
524 | if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':' |
525 | |
526 | e.eltwise.beta = std::stof(parser::get_substr(subs, subs_pos, ':')); |
527 | if (subs_pos == std::string::npos) continue; |
528 | if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':' |
529 | } else if (e.is_binary_kind()) { |
530 | e.binary.src1_dt |
531 | = str2dt(parser::get_substr(subs, subs_pos, ':').c_str()); |
532 | if (e.binary.src1_dt == dnnl_data_type_undef) return FAIL; |
533 | if (subs_pos == std::string::npos) continue; |
534 | if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':' |
535 | |
536 | e.binary.policy |
537 | = str2policy(parser::get_substr(subs, subs_pos, ':')); |
538 | if (e.binary.policy == POLICY_TOTAL) return FAIL; |
539 | if (subs_pos == std::string::npos) continue; |
540 | if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':' |
541 | |
542 | e.binary.tag = parser::get_substr(subs, subs_pos, ':'); |
543 | SAFE(check_tag(e.binary.tag), WARN); |
544 | } else if (e.is_prelu_kind()) { |
545 | e.prelu.policy |
546 | = str2policy(parser::get_substr(subs, subs_pos, ':')); |
547 | if (e.prelu.policy == POLICY_TOTAL) return FAIL; |
548 | if (subs_pos == std::string::npos) continue; |
549 | if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':' |
550 | } |
551 | if (subs_pos == std::string::npos) continue; |
552 | if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':' |
553 | } |
554 | return OK; |
555 | } |
556 | |
557 | bool attr_t::is_def(bool skip_fpmath) const { |
558 | return scales.is_def() && zero_points.is_def() && post_ops.is_def() |
559 | && scratchpad_mode == dnnl_scratchpad_mode_library |
560 | && IMPLICATION( |
561 | !skip_fpmath, fpmath_mode == dnnl_fpmath_mode_strict); |
562 | } |
563 | |
564 | int attr_t::post_ops_t::find(pk_t kind, int start, int stop) const { |
565 | if (stop == -1) stop = len(); |
566 | stop = MIN2(stop, len()); |
567 | for (int idx = start; idx < stop; ++idx) |
568 | if (entry[idx].kind == kind) return idx; |
569 | return -1; |
570 | } |
571 | |
572 | bool attr_t::post_ops_t::entry_t::is_sum_kind() const { |
573 | return kind == SUM; |
574 | } |
575 | bool attr_t::post_ops_t::entry_t::is_convolution_kind() const { |
576 | return kind == DW || kind == DW_K3S1P1 || kind == DW_K3S2P1; |
577 | } |
578 | bool attr_t::post_ops_t::entry_t::is_eltwise_kind() const { |
579 | return kind > ELTWISE_START && kind < ELTWISE_END; |
580 | } |
581 | bool attr_t::post_ops_t::entry_t::is_binary_kind() const { |
582 | return kind > pk_t::BINARY_START && kind < pk_t::BINARY_END; |
583 | } |
584 | bool attr_t::post_ops_t::entry_t::is_prelu_kind() const { |
585 | return kind == PRELU; |
586 | } |
587 | |
588 | int attr_t::post_ops_t::convolution_index() const { |
589 | for (int i = 0; i < len(); ++i) { |
590 | if (entry[i].is_convolution_kind()) return i; |
591 | } |
592 | return -1; |
593 | } |
594 | |
595 | int attr_t::post_ops_t::eltwise_index() const { |
596 | for (int i = 0; i < len(); ++i) { |
597 | if (entry[i].is_eltwise_kind()) return i; |
598 | } |
599 | return -1; |
600 | } |
601 | |
602 | int attr_t::post_ops_t::binary_index() const { |
603 | for (int i = 0; i < len(); ++i) { |
604 | if (entry[i].is_binary_kind()) return i; |
605 | } |
606 | return -1; |
607 | } |
608 | |
609 | int attr_t::post_ops_t::prelu_index() const { |
610 | for (int i = 0; i < len(); ++i) { |
611 | if (entry[i].is_prelu_kind()) return i; |
612 | } |
613 | return -1; |
614 | } |
615 | |
616 | std::ostream &operator<<(std::ostream &s, const policy_t &policy) { |
617 | s << attr_t::policy2str(policy); |
618 | return s; |
619 | } |
620 | |
621 | std::ostream &operator<<(std::ostream &s, const attr_t::scale_t &scale) { |
622 | s << scale.policy << ":" << scale.scale; |
623 | if (scale.runtime) s << '*'; |
624 | return s; |
625 | } |
626 | |
627 | std::ostream &operator<<( |
628 | std::ostream &s, const attr_t::zero_points_t &zero_points) { |
629 | const char *delim = "" ; |
630 | for (const auto &point : zero_points.points) { |
631 | s << delim; |
632 | s << arg2str(point.first) << ":" << point.second.policy << ":" |
633 | << point.second.value; |
634 | if (point.second.runtime) s << '*'; |
635 | delim = "+" ; |
636 | } |
637 | |
638 | return s; |
639 | } |
640 | |
641 | std::ostream &operator<<(std::ostream &s, const attr_t::arg_scales_t &scales) { |
642 | const char *delim = "" ; |
643 | for (const auto &v : scales.scales) { |
644 | if (!v.second.is_def()) { |
645 | s << delim; |
646 | s << arg2str(v.first) << ":" << v.second; |
647 | delim = "+" ; |
648 | } |
649 | } |
650 | return s; |
651 | } |
652 | |
653 | std::ostream &operator<<(std::ostream &s, const attr_t::post_ops_t::kind_t &k) { |
654 | s << attr_t::post_ops_t::kind2str(k); |
655 | return s; |
656 | } |
657 | |
// Serializes post-ops as a '+'-separated list of `kind[:field...]` entries,
// using the same field layout that post_ops_t::from_str() parses. Trailing
// fields that hold default values are omitted.
std::ostream &operator<<(std::ostream &s, const attr_t::post_ops_t &post_ops) {
    for (int idx = 0; idx < post_ops.len(); ++idx) {
        if (idx > 0) s << "+";

        const auto &e = post_ops.entry[idx];
        s << e.kind;

        if (e.is_sum_kind()) {
            // Each field is printed if it or any later field is non-default,
            // so that positional parsing stays unambiguous.
            if (e.sum.scale != 1.0f || e.sum.zero_point != 0
                    || e.sum.dt != dnnl_data_type_undef)
                s << ":" << e.sum.scale;
            if (e.sum.zero_point != 0 || e.sum.dt != dnnl_data_type_undef)
                s << ":" << e.sum.zero_point;
            if (e.sum.dt != dnnl_data_type_undef) s << ":" << e.sum.dt;
        } else if (e.is_convolution_kind()) {
            // Only the generic DW kind carries explicit kernel/stride/padding.
            if (e.kind == pk_t::DW) {
                s << ":k" << e.convolution.kernel << "s" << e.convolution.stride
                  << "p" << e.convolution.padding;
            }
            const auto &c_ws = e.convolution.wei_scale;
            const auto &c_ds = e.convolution.dst_scale;
            // dst_dt defaults to f32 and is printed only when needed.
            if (e.convolution.dst_dt != dnnl_f32 || !c_ws.is_def()
                    || !c_ds.is_def())
                s << ":" << e.convolution.dst_dt;
            if (!c_ws.is_def() || !c_ds.is_def()) s << ":" << c_ws;
            if (!c_ds.is_def()) s << ":" << c_ds;
        } else if (e.is_eltwise_kind()) {
            if (e.eltwise.beta != 0.f)
                s << ":" << e.eltwise.alpha << ":" << e.eltwise.beta;
            else if (e.eltwise.alpha != 0.f)
                s << ":" << e.eltwise.alpha;
        } else if (e.is_binary_kind()) {
            s << ":" << e.binary.src1_dt;
            if (e.binary.policy != policy_t::COMMON) {
                s << ":" << e.binary.policy;
                // A mask >= 4 has a bit set for dim 2 or higher; only then is
                // the tag printed.
                if (attr_t::get_default_mask(e.binary.policy) >= 4)
                    s << ":" << e.binary.tag;
            }
        } else if (e.is_prelu_kind()) {
            if (e.prelu.policy != policy_t::COMMON) {
                s << ":" << e.prelu.policy;
            }
        } else {
            assert(!"unknown kind");
            s << "unknown_kind";
        }
    }

    return s;
}
708 | |
709 | std::ostream &operator<<(std::ostream &s, dnnl_scratchpad_mode_t sm) { |
710 | s << scratchpad_mode2str(sm); |
711 | return s; |
712 | } |
713 | |
714 | std::ostream &operator<<(std::ostream &s, dnnl_fpmath_mode_t fm) { |
715 | s << fpmath_mode2str(fm); |
716 | return s; |
717 | } |
718 | |
719 | std::ostream &operator<<(std::ostream &s, const attr_t &attr) { |
720 | if (!attr.is_def()) { |
721 | if (!attr.scales.is_def()) s << "--attr-scales=" << attr.scales << " " ; |
722 | if (!attr.zero_points.is_def()) |
723 | s << "--attr-zero-points=" << attr.zero_points << " " ; |
724 | if (!attr.post_ops.is_def()) |
725 | s << "--attr-post-ops=" << attr.post_ops << " " ; |
726 | if (attr.scratchpad_mode != dnnl_scratchpad_mode_library) |
727 | s << "--attr-scratchpad=" << attr.scratchpad_mode << " " ; |
728 | if (attr.fpmath_mode != dnnl_fpmath_mode_strict) |
729 | s << "--attr-fpmath=" << attr.fpmath_mode << " " ; |
730 | } |
731 | return s; |
732 | } |
733 | |
734 | std::ostream &operator<<(std::ostream &s, bench_mode_t mode) { |
735 | if (is_bench_mode(RUN) && !(is_bench_mode(CORR) || is_bench_mode(PERF))) |
736 | s << "R" ; |
737 | if (is_bench_mode(CORR)) s << "C" ; |
738 | if (is_bench_mode(PERF)) s << "P" ; |
739 | if (is_bench_mode(LIST)) s << "L" ; |
740 | if (is_bench_mode(PROF)) s << "O" ; |
741 | return s; |
742 | } |
743 | |
744 | std::ostream &operator<<(std::ostream &s, memory_kind_ext_t memory_kind) { |
745 | switch (memory_kind) { |
746 | case memory_kind_ext_t::usm: s << "usm" ; break; |
747 | case memory_kind_ext_t::buffer: s << "buffer" ; break; |
748 | case memory_kind_ext_t::usm_device: s << "usm_device" ; break; |
749 | case memory_kind_ext_t::usm_shared: s << "usm_shared" ; break; |
750 | default: assert(!"unexpected" ); break; |
751 | } |
752 | return s; |
753 | } |
754 | |
// Writes the driver name followed by the global benchdnn options as
// command-line flags, each followed by a space. An option is written when
// `canonical` is set or when its value differs from the default compared
// against below. --skip-impl is written only when non-empty, and
// --memory-kind only in SYCL or OpenCL GPU-runtime builds.
std::ostream &dump_global_params(std::ostream &s) {
    s << "--" << driver_name << " ";
    if (canonical) s << "--canonical=" << bool2str(canonical) << " ";
    if (canonical || engine_tgt_kind != dnnl_cpu) {
        s << "--engine=" << engine_tgt_kind;
        // The engine index is appended only when it is not 0.
        if (engine_index != 0) s << ":" << engine_index;
        s << " ";
    }
    if (canonical || fast_ref_gpu != true)
        s << "--fast-ref-gpu=" << bool2str(fast_ref_gpu) << " ";
    if (!skip_impl.empty()) s << "--skip-impl=" << skip_impl << " ";
    if (canonical || mem_check != true)
        s << "--mem-check=" << bool2str(mem_check) << " ";
    if (canonical || allow_enum_tags_only != true)
        s << "--allow-enum-tags-only=" << bool2str(allow_enum_tags_only) << " ";
    if (canonical || hints.get() != isa_hints_t::none)
        s << "--cpu-isa-hints=" << isa_hints_t::hints2str(hints) << " ";
    if (canonical || bench_mode != CORR) s << "--mode=" << bench_mode << " ";
    if (canonical || attr_same_pd_check != false)
        s << "--attr-same-pd-check=" << bool2str(attr_same_pd_check) << " ";
#if defined(DNNL_WITH_SYCL) || DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
    if (canonical || memory_kind != default_memory_kind)
        s << "--memory-kind=" << memory_kind << " ";
#endif

    return s;
}
782 | |
783 | dnnl_engine_kind_t str2engine_kind(const char *str) { |
784 | const char *param = "cpu" ; |
785 | if (!strncasecmp(param, str, strlen(param))) return dnnl_cpu; |
786 | |
787 | param = "gpu" ; |
788 | if (!strncasecmp(param, str, strlen(param))) return dnnl_gpu; |
789 | |
790 | assert(!"not expected" ); |
791 | return dnnl_cpu; |
792 | } |
793 | |
794 | dnnl_scratchpad_mode_t str2scratchpad_mode(const char *str) { |
795 | const char *param = "library" ; |
796 | if (!strncasecmp(param, str, strlen(param))) |
797 | return dnnl_scratchpad_mode_library; |
798 | |
799 | param = "user" ; |
800 | if (!strncasecmp(param, str, strlen(param))) |
801 | return dnnl_scratchpad_mode_user; |
802 | |
803 | assert(!"not expected" ); |
804 | return dnnl_scratchpad_mode_library; |
805 | } |
806 | |
807 | dnnl_fpmath_mode_t str2fpmath_mode(const char *str) { |
808 | if (std::strcmp(str, "" ) == 0) { |
809 | dnnl_fpmath_mode_t ret; |
810 | dnnl_get_default_fpmath_mode(&ret); |
811 | return ret; |
812 | } |
813 | |
814 | #define CASE(fpm) \ |
815 | param = #fpm; \ |
816 | if (!strncasecmp(param, str, strlen(param))) return dnnl_fpmath_mode_##fpm; |
817 | |
818 | const char *param; |
819 | |
820 | CASE(strict); |
821 | CASE(bf16); |
822 | CASE(f16); |
823 | CASE(tf32); |
824 | CASE(any); |
825 | |
826 | assert(!"not expected" ); |
827 | return dnnl_fpmath_mode_strict; |
828 | |
829 | #undef CASE |
830 | } |
831 | |
832 | void attr_args_t::prepare_scales(const attr_t &attr, int arg, const void *vals, |
833 | int64_t count, int mask) { |
834 | insert(arg, vals, count, mask, attr.scales.get(arg).runtime); |
835 | } |
836 | |
// Description of the extra ("right-hand side") input tensor of a post-op, as
// produced by get_po_rhs_tensor_entry() for prelu and binary entries and
// consumed by attr_args_t::prepare_post_ops_mds(). Kept an aggregate so it
// can be brace-initialized field by field.
struct post_ops_rhs_tensor_entry_t {
    dnnl_data_type_t dt; // data type of the rhs tensor
    policy_t policy; // converted to a dims mask via attr_t::get_default_mask()
    std::string tag; // memory format tag used to build the rhs descriptor
    int arg_attr_mask; // DNNL_ARG_* value OR-ed with the post-op argument id
};
843 | |
844 | namespace { |
845 | |
846 | post_ops_rhs_tensor_entry_t get_po_rhs_tensor_entry( |
847 | const attr_t::post_ops_t::entry_t &entry) { |
848 | if (entry.is_prelu_kind()) { |
849 | const auto &prelu = entry.prelu; |
850 | return {dnnl_f32, prelu.policy, tag::axb, DNNL_ARG_WEIGHTS}; |
851 | } else if (entry.is_binary_kind()) { |
852 | const auto &binary = entry.binary; |
853 | return {binary.src1_dt, binary.policy, binary.tag, DNNL_ARG_SRC_1}; |
854 | } |
855 | |
856 | return post_ops_rhs_tensor_entry_t {}; |
857 | } |
858 | |
859 | } // namespace |
860 | |
861 | int attr_args_t::prepare_post_ops_mds( |
862 | const attr_t &attr, int ndims, const dnnl_dims_t dims) { |
863 | const auto &po = attr.post_ops; |
864 | // iterate over all post ops and prepare md for each binary |
865 | for (int idx = 0; idx < po.len(); ++idx) { |
866 | const auto &e = po.entry[idx]; |
867 | if (e.is_binary_kind() || e.is_prelu_kind()) { |
868 | |
869 | const auto po_rhs_tensor_entry = get_po_rhs_tensor_entry(e); |
870 | const int mask |
871 | = attr_t::get_default_mask(po_rhs_tensor_entry.policy); |
872 | |
873 | // deduce binary, prelu dims based on input policy |
874 | dnnl_dims_t rhs_tensor_dims = {}; |
875 | for (auto d = 0; d < ndims; ++d) |
876 | rhs_tensor_dims[d] = (!(mask & (1 << d))) ? 1 : dims[d]; |
877 | |
878 | auto rhs_tensor_desc = dnn_mem_t::init_md(ndims, rhs_tensor_dims, |
879 | po_rhs_tensor_entry.dt, po_rhs_tensor_entry.tag); |
880 | mds.emplace((DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) |
881 | | po_rhs_tensor_entry.arg_attr_mask), |
882 | std::move(rhs_tensor_desc)); |
883 | } |
884 | } |
885 | |
886 | return OK; |
887 | } |
888 | |
889 | void attr_args_t::prepare_dw_post_op( |
890 | const attr_t &attr, dnnl_data_type_t wei_dt, dnnl_data_type_t bia_dt) { |
891 | const int dw_idx = attr.post_ops.convolution_index(); |
892 | if (dw_idx == -1) return; |
893 | |
894 | dw_entry.wei_dt = wei_dt; |
895 | dw_entry.bia_dt = bia_dt; |
896 | } |
897 | |
// Builds a new dnnl_primitive_attr_t from the benchdnn attribute description.
// - Scales and zero points: only masks are registered here. SAFE_V rejects
//   any entry that is not marked `runtime`, so the values themselves are
//   presumably supplied at execution time.
// - Post-ops: appended in attr.post_ops order. Binary src1 descriptors are
//   looked up in `attr_args` (presumably filled by prepare_post_ops_mds()).
// - Scratchpad and fpmath modes are copied from `attr`.
// Library/status errors are routed through DNN_SAFE_V/SAFE_V.
// NOTE(review): the handle is created here; presumably the caller destroys it.
dnnl_primitive_attr_t create_dnnl_attr(
        const attr_t &attr, const attr_args_t &attr_args) {
    dnnl_primitive_attr_t dnnl_attr = nullptr;
    DNN_SAFE_V(dnnl_primitive_attr_create(&dnnl_attr));

    if (!attr.scales.is_def()) {
        const auto &as = attr.scales;
        for (const auto &arg : as.scales) {
            const int arg_name = arg.first;
            if (as.is_def(arg_name)) continue;

            // Per-OC weights scales that have an entry in `attr_args` take
            // the mask stored there; all other scales derive the mask from
            // their policy.
            if (arg_name == DNNL_ARG_WEIGHTS
                    && arg.second.policy == policy_t::PER_OC
                    && !attr_args.get(arg_name).is_def()) {
                const auto &e = attr_args.get(arg_name);
                // Only RT scales are supported.
                SAFE_V(e.runtime ? OK : FAIL);
                // Only common policy is supported in the library at this point
                // NOTE(review): the mask passed below comes from attr_args and
                // need not be common; this comment may be stale.
                int mask = e.mask;

                DNN_SAFE_V(dnnl_primitive_attr_set_scales_mask(
                        dnnl_attr, arg_name, mask));
            } else {
                const auto &e = arg.second;
                // Only RT scales are supported.
                SAFE_V(e.runtime ? OK : FAIL);
                // Only common policy is supported in the library at this point
                int mask = attr_t::get_default_mask(e.policy, arg_name);

                DNN_SAFE_V(dnnl_primitive_attr_set_scales_mask(
                        dnnl_attr, arg_name, mask));
            }
        }
    }

    if (!attr.zero_points.is_def()) {
        const auto &zp = attr.zero_points;
        for (const auto &arg : zp.points) {
            const auto arg_name = arg.first;
            if (zp.is_def(arg_name)) continue;

            const auto &e = arg.second;
            // Only RT zero points are supported.
            SAFE_V(e.runtime ? OK : FAIL);
            // Only common policy/single RT value are supported in the library
            // at this point
            int mask = attr_t::get_default_mask(e.policy);

            DNN_SAFE_V(dnnl_primitive_attr_set_zero_points_mask(
                    dnnl_attr, arg_name, mask));
        }
    }

    if (!attr.post_ops.is_def()) {
        dnnl_post_ops_t ops;
        DNN_SAFE_V(dnnl_post_ops_create(&ops));

        const auto &po = attr.post_ops;
        for (int idx = 0; idx < po.len(); ++idx) {
            const auto &e = po.entry[idx];
            if (e.is_sum_kind()) {
                DNN_SAFE_V(dnnl_post_ops_append_sum(
                        ops, e.sum.scale, e.sum.zero_point, e.sum.dt));
            } else if (e.is_convolution_kind()) {
                // Weights/bias data types were recorded earlier in attr_args
                // (see attr_args_t::prepare_dw_post_op()).
                const auto wei_dt = attr_args.get_dw_arg(DNNL_ARG_WEIGHTS);
                const auto bia_dt = attr_args.get_dw_arg(DNNL_ARG_BIAS);

                DNN_SAFE_V(dnnl_post_ops_append_dw(ops, wei_dt, bia_dt,
                        e.convolution.dst_dt, e.convolution.kernel,
                        e.convolution.stride, e.convolution.padding));

                const auto &wei_policy = e.convolution.wei_scale.policy;
                int wei_mask = attr_t::get_default_mask(
                        wei_policy, DNNL_ARG_WEIGHTS);
                // dw conv always has group dim
                // A non-zero mask m is replaced by 1 << m (e.g. 1 -> 2).
                // NOTE(review): confirm this matches the dw weights layout.
                if (wei_mask) wei_mask = 1 << wei_mask;
                if (e.convolution.wei_scale.runtime)
                    DNN_SAFE_V(dnnl_primitive_attr_set_scales_mask(dnnl_attr,
                            DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS,
                            wei_mask));

                const auto &dst_policy = e.convolution.dst_scale.policy;
                int dst_mask
                        = attr_t::get_default_mask(dst_policy, DNNL_ARG_DST);
                if (e.convolution.dst_scale.runtime)
                    DNN_SAFE_V(dnnl_primitive_attr_set_scales_mask(dnnl_attr,
                            DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST, dst_mask));

            } else if (e.is_eltwise_kind()) {
                DNN_SAFE_V(dnnl_post_ops_append_eltwise(
                        ops, e.eltwise.alg, e.eltwise.alpha, e.eltwise.beta));
            } else if (e.is_binary_kind()) {
                // The src1 descriptor must already exist in attr_args.
                const auto &src1_md = attr_args.get_md(
                        (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1));
                assert(query_md_ndims(src1_md) != 0);
                DNN_SAFE_V(dnnl_post_ops_append_binary(
                        ops, e.binary.alg, src1_md));
            } else if (e.is_prelu_kind()) {
                const auto &policy = e.prelu.policy;
                const auto mask = attr_t::get_default_mask(policy);
                DNN_SAFE_V(dnnl_post_ops_append_prelu(ops, mask));
            } else {
                assert(!"unknown attr::post_ops::kind" );
            }
        }
        DNN_SAFE_V(dnnl_primitive_attr_set_post_ops(dnnl_attr, ops));
        // Sanity check: the attribute must hold every requested post-op.
        auto c_ops = query_post_ops(dnnl_attr);
        SAFE_V(dnnl_post_ops_len(c_ops) == po.len() ? OK : FAIL);

        // The local post-ops object is released after it has been set on
        // the attribute.
        DNN_SAFE_V(dnnl_post_ops_destroy(ops));
    }

    DNN_SAFE_V(dnnl_primitive_attr_set_scratchpad_mode(
            dnnl_attr, attr.scratchpad_mode));

    DNN_SAFE_V(
            dnnl_primitive_attr_set_fpmath_mode(dnnl_attr, attr.fpmath_mode));

    return dnnl_attr;
}
1018 | |
// Exception-free replacement for std::stoi, restricted to non-negative
// decimal prefixes. It parses the leading run of ASCII digits of `s`, stores
// the number of consumed characters in `*idx` and returns the parsed value.
// If `s` does not start with a digit, or the value does not fit into `int`,
// it sets `*idx` to 0 and returns 0. std::stoi would throw
// std::out_of_range in the latter case.
static int stoi_safe(const std::string &s, size_t *idx) {
    size_t pos = 0;
    int value = 0;
    // Compare against '0'..'9' directly: std::isdigit on a negative `char`
    // is undefined behavior.
    while (pos < s.size() && s[pos] >= '0' && s[pos] <= '9') {
        const int digit = s[pos] - '0';
        // value * 10 + digit must not exceed INT_MAX.
        if (value > (INT_MAX - digit) / 10) {
            *idx = 0;
            return 0;
        }
        value = value * 10 + digit;
        ++pos;
    }
    *idx = pos;
    return value;
}
1028 | |
1029 | static bool is_abc_tag(const std::string &tag) { |
1030 | if (tag == tag::undef || tag == tag::any) return true; |
1031 | |
1032 | bool mask[DNNL_MAX_NDIMS] = {}; |
1033 | for (auto &c : tag) { |
1034 | if (!std::isalpha(c)) continue; |
1035 | int idx = std::tolower(c) - 'a'; |
1036 | if (idx < 0 || idx >= DNNL_MAX_NDIMS) return false; |
1037 | mask[idx] = true; |
1038 | } |
1039 | // Check there are no gaps, e.g. [1 1 1 1 0 0 ...]. |
1040 | for (int i = 0; i < DNNL_MAX_NDIMS; i++) { |
1041 | if (mask[i]) continue; |
1042 | for (int j = i + 1; j < DNNL_MAX_NDIMS; j++) |
1043 | if (mask[j]) return false; |
1044 | break; |
1045 | } |
1046 | return true; |
1047 | } |
1048 | |
// Validates an abc-style format tag and returns OK or FAIL.
// Rules enforced below:
// - the tag is non-empty and passes is_abc_tag() (letters in range, no gaps);
// - with `check_enum_tags_only`, the tag only has to be recognized by
//   str2fmt_tag() (dnnl_format_tag_last means unknown); the remaining rules
//   are skipped;
// - a block size must be a positive number immediately followed by a
//   lowercase letter (uppercase letters cannot carry a block size);
// - once the first block size is seen, every following letter needs one;
// - a letter without a block size must be the first occurrence of its
//   dimension;
// - a blocked letter requires an earlier uppercase or blocked occurrence of
//   the same dimension, and every uppercase letter needs such a follow-up;
// - the used dimensions have no gaps (e.g. "acd" is rejected).
int check_abc_tag(const std::string &tag_, bool check_enum_tags_only) {
    if (tag_.empty()) return FAIL;
    if (!is_abc_tag(tag_)) return FAIL;
    if (check_enum_tags_only) {
        if (str2fmt_tag(tag_.c_str()) == dnnl_format_tag_last) return FAIL;
        return OK;
    }

    // Per-dimension state: how the last occurrence of the letter looked.
    enum class dim_state_t { undef = 0, upper, lower, lower_with_block };
    dim_state_t dim_states[DNNL_MAX_NDIMS] = {};
    bool in_inner_block = false;
    auto tag = tag_;
    while (!tag.empty()) {
        // Parse block size if presented.
        size_t idx;
        int block = stoi_safe(tag, &idx);
        // An explicit block size of 0 is invalid.
        if (block == 0 && idx != 0) return FAIL;
        if (idx == 0) block = 0;
        if (block > 0) in_inner_block = true;

        // Move to the first position after the block.
        tag = tag.substr(idx);
        // A block size must be followed by a letter.
        if (tag.empty()) return FAIL;

        char c = tag[0];
        bool is_lower = ('a' <= c && c <= 'a' + DNNL_MAX_NDIMS - 1);
        bool is_upper = ('A' <= c && c <= 'A' + DNNL_MAX_NDIMS - 1);
        if (!is_lower && !is_upper) return FAIL;

        // Uppercase cannot be with block.
        if (is_upper && block != 0) return FAIL;
        // Block sizes are required within inner block.
        if (block == 0 && in_inner_block) return FAIL;

        // Check rules related to lowercase/uppercase/block order.
        int dim_idx = std::tolower(c) - 'a';
        dim_state_t prev_state = dim_states[dim_idx];
        dim_state_t cur_state = is_upper
                ? dim_state_t::upper
                : block != 0 ? dim_state_t::lower_with_block
                             : dim_state_t::lower;

        switch (cur_state) {
            case dim_state_t::upper:
            case dim_state_t::lower:
                // Letter without block must be the first.
                if (prev_state != dim_state_t::undef) return FAIL;
                break;
            case dim_state_t::lower_with_block:
                // Letter with block must be after uppercase or after a letter
                // with block.
                if (prev_state != dim_state_t::upper
                        && prev_state != dim_state_t::lower_with_block)
                    return FAIL;
                break;
            default: assert(!"not expected" );
        }

        // Update state, move to the next position.
        dim_states[dim_idx] = cur_state;
        tag = tag.substr(1);
    }

    for (int i = 0; i < DNNL_MAX_NDIMS; i++) {
        // Uppercase letter must be followed by a blocked lowercase one.
        if (dim_states[i] == dim_state_t::upper) return FAIL;

        // Ensure there are no gaps (e.g. acd).
        if (dim_states[i] == dim_state_t::undef) {
            for (int j = i + 1; j < DNNL_MAX_NDIMS; j++)
                if (dim_states[j] != dim_state_t::undef) return FAIL;
            break;
        }
    }

    return OK;
}
1126 | |
// Returns a copy of `tag_` with every occurrence of the letter `c` removed,
// together with the block-size digits that directly precede each occurrence.
// Example: trimming 'b' from "aBc16b" yields "aBc". The match is
// case-sensitive, so 'B' and 'b' must be trimmed separately.
static std::string trim_letter(const std::string &tag_, char c) {
    auto tag = tag_;
    for (size_t pos = tag.find(c); pos != std::string::npos;
            pos = tag.find(c)) {
        tag.erase(pos, 1);
        // Drop the block size digits preceding the removed letter. A letter at
        // position 0 has none, but later occurrences must still be processed,
        // so keep looping instead of returning early.
        while (pos > 0
                && std::isdigit(static_cast<unsigned char>(tag[pos - 1]))) {
            tag.erase(pos - 1, 1);
            --pos;
        }
    }
    return tag;
}
1143 | |
1144 | // Tries to map a tag to an abc-tag according to a logical tag. For example: |
1145 | // nchw -> abcd. |
1146 | static std::string try_map_tag( |
1147 | const std::string &logical_tag, const std::string &tag, int *nmatched) { |
1148 | // Check if all the required letters are presented. |
1149 | for (auto &c : logical_tag) { |
1150 | if (std::toupper(c) == c |
1151 | && tag.find(std::tolower(c)) == std::string::npos) |
1152 | return {}; |
1153 | } |
1154 | |
1155 | // Check that all letters are known and assign indices to letters. |
1156 | int logical_indices[DNNL_MAX_NDIMS] = {}; |
1157 | for (auto &c : tag) { |
1158 | if (!std::isalpha(c)) continue; |
1159 | |
1160 | auto lower_pos = logical_tag.find(std::tolower(c)); |
1161 | auto upper_pos = logical_tag.find(std::toupper(c)); |
1162 | auto pos = (lower_pos == std::string::npos ? upper_pos : lower_pos); |
1163 | if (pos == std::string::npos) return {}; |
1164 | |
1165 | logical_indices[pos] = 1; |
1166 | } |
1167 | |
1168 | for (int i = 0, idx = 0; i < (int)logical_tag.size(); i++) { |
1169 | if (logical_indices[i] == 0) continue; |
1170 | logical_indices[i] = idx++; |
1171 | } |
1172 | |
1173 | (*nmatched)++; |
1174 | std::string mapped_tag = tag; |
1175 | for (int i = 0; i < (int)tag.size(); i++) { |
1176 | char c = tag[i]; |
1177 | if (!std::isalpha(tag[i])) continue; |
1178 | auto pos = logical_tag.find(std::tolower(c)); |
1179 | if (pos == std::string::npos) pos = logical_tag.find(std::toupper(c)); |
1180 | |
1181 | mapped_tag[i] |
1182 | = (char)(tag[i] - std::tolower(c) + 'a' + logical_indices[pos]); |
1183 | } |
1184 | return mapped_tag; |
1185 | } |
1186 | |
1187 | // Maps a tag to an abc-tag. |
1188 | static std::string map_tag_letters(const std::string &tag) { |
1189 | int nmatched = 0; |
1190 | |
1191 | // Mapping rules: |
1192 | // - Uppercase letters are mandatory |
1193 | // - Lowercase letters are optional |
1194 | auto tag_goidhw = try_map_tag("GOIdhw" , tag, &nmatched); |
1195 | auto tag_oidhw = try_map_tag("OIdhw" , tag, &nmatched); |
1196 | auto tag_ncdhw = try_map_tag("NCdhw" , tag, &nmatched); |
1197 | auto tag_tnc = try_map_tag("TNc" , tag, &nmatched); |
1198 | auto tag_ldnc = try_map_tag("LDNC" , tag, &nmatched); |
1199 | auto tag_ldigo = try_map_tag("LDigO" , tag, &nmatched); |
1200 | |
1201 | if (nmatched == 0) return tag; |
1202 | if (nmatched > 1) assert(!"Not expected: ambiguous tag." ); |
1203 | |
1204 | if (!tag_goidhw.empty()) return tag_goidhw; |
1205 | if (!tag_oidhw.empty()) return tag_oidhw; |
1206 | if (!tag_ncdhw.empty()) return tag_ncdhw; |
1207 | if (!tag_tnc.empty()) return tag_tnc; |
1208 | if (!tag_ldnc.empty()) return tag_ldnc; |
1209 | if (!tag_ldigo.empty()) return tag_ldigo; |
1210 | |
1211 | return tag; |
1212 | } |
1213 | |
1214 | std::string trim_tag(const std::string &tag, int ndims) { |
1215 | int mask = 0; |
1216 | for (int d = 0; d < ndims; d++) { |
1217 | mask += (1 << d); |
1218 | } |
1219 | return trim_tag_by_mask(tag, mask); |
1220 | } |
1221 | |
// Removes from `tag` every dimension whose bit is not set in `mask`. Bit d
// stands for letters 'a' + d and 'A' + d, and trim_letter() also drops the
// block sizes that precede them. The surviving letters are then renamed so
// they are consecutive from 'a'/'A', preserving case.
// Example: mask = 0b1010 keeps "b" and "d", which become "a" and "b".
std::string trim_tag_by_mask(const std::string &tag, int mask) {
    std::string trimmed_tag = tag;
    // Number of dimensions (among the first DNNL_MAX_NDIMS) kept by `mask`.
    int ndims_saved = 0;
    for (char c = 'a', d = 0; c < 'a' + (char)(DNNL_MAX_NDIMS); c++, d++) {
        if (!(mask & (1 << d))) {
            trimmed_tag = trim_letter(trimmed_tag, c);
            trimmed_tag = trim_letter(trimmed_tag, std::toupper(c));
        } else {
            ndims_saved++;
        }
    }

    // Mask may operate over non-consecutive dimensions. The piece below will
    // make trimmed_tag consist of consecutive dimensions starting from "a" or
    // "A". E.g., mask = 2 + 8 = 10, trimmed_tag will contain "b" and "d"
    // letters, and will be converted into one with "a" and "b".
    int mask_copy = mask;
    for (int i = 0; i < ndims_saved; i++) {
        // Despite its name, `dist_to_a` counts the unset bits between the
        // previous kept dimension and the i-th one. Earlier iterations have
        // already shifted letters down, so only this gap remains to close.
        int dist_to_a = 0;
        while (mask_copy % 2 == 0) {
            mask_copy /= 2;
            dist_to_a++;
        }
        mask_copy /= 2;
        if (dist_to_a == 0) continue;

        // Shift every letter above the i-th target letter down by the gap.
        for (size_t j = 0; j < trimmed_tag.size(); j++) {
            char str_j = trimmed_tag[j];
            if (std::isalpha(str_j) && std::tolower(str_j) > 'a' + i) {
                std::string rep_str(1, str_j - dist_to_a);
                trimmed_tag.replace(j, 1, rep_str);
            }
        }
    }

    return trimmed_tag;
}
1259 | |
// Converts a user-provided tag into an abc-style tag:
// - "undef", "any", and any tag with ndims == 0 are returned unchanged;
// - "x" becomes "a" (asserting ndims == 1 when ndims is non-negative);
// - tags containing 'x' (abx, axb, ...) get 'x' expanded into the trailing
//   dimensions and are then trimmed to the resulting number of dimensions;
// - everything else goes through map_tag_letters() (e.g. nchw -> abcd).
// ndims == -1 appears to mean "unknown": in the 'x' branch it is then derived
// from the tag. NOTE(review): confirm -1 is the default argument in the header.
std::string normalize_tag(const std::string &tag_, int ndims) {
    std::string tag = tag_;
    if (tag == tag::undef || tag == tag::any || ndims == 0) return tag;
    if (tag == tag::x) {
        if (ndims >= 0) assert(ndims == 1);
        return "a" ;
    }

    // Handle meta-tags (abx, axb, etc).
    auto pos = tag.find("x" );
    if (pos != std::string::npos) {
        // Non-grouped tags will start `x` from `c`, but grouped will most of
        // times start `x` from `d`.
        // `x` starts right after the highest lowercase letter in b..(a +
        // DNNL_MAX_NDIMS - 1) present in the tag, or at 'c' if there is none.
        char start_x = 'c';
        for (char c = 'a' + DNNL_MAX_NDIMS - 1; c >= 'b'; c--) {
            if (tag.find(c) != std::string::npos) {
                start_x = c + 1;
                break;
            }
        }
        // Adjust ndims if they are not specified.
        // With unknown ndims, `x` expands to exactly one dimension.
        int meta_ndims = (ndims == -1 ? (start_x - 'a' + 1) : ndims);
        // Letters start_x, start_x + 1, ... up to meta_ndims in total.
        std::string tail;
        for (int i = 0; i < meta_ndims - (start_x - 'a'); i++)
            tail += (start_x + i);
        return trim_tag(tag.replace(pos, 1, tail), meta_ndims);
    }

    return map_tag_letters(tag);
}
1290 | |
1291 | int check_tag(const std::string &tag_, bool check_enum_tags_only) { |
1292 | auto tag = normalize_tag(tag_); |
1293 | if (tag == tag::undef || tag == tag::any) return OK; |
1294 | return check_abc_tag(tag, check_enum_tags_only); |
1295 | } |
1296 | |
1297 | void maybe_scale(const attr_t &attr, float &d, const float *scales, int64_t c, |
1298 | int arg, bool opposite_scale) { |
1299 | if (attr.scales.is_def()) return; |
1300 | |
1301 | const auto &e = attr.scales.get(arg); |
1302 | if (!e.is_def()) { |
1303 | int64_t idx = e.policy == policy_t::COMMON ? 0 : c; |
1304 | if (opposite_scale) |
1305 | d /= scales[idx]; |
1306 | else |
1307 | d *= scales[idx]; |
1308 | } |
1309 | } |
1310 | |
1311 | void maybe_zero_point(const attr_t &attr, float &d, const int32_t *zero_points, |
1312 | int64_t c, int arg, bool opposite_zero_point) { |
1313 | if (attr.zero_points.is_def()) return; |
1314 | |
1315 | const auto &e = attr.zero_points.get(arg); |
1316 | if (!e.is_def()) { |
1317 | const int idx = e.policy == policy_t::COMMON ? 0 : c; |
1318 | const int zp_sign = opposite_zero_point ? -1 : 1; |
1319 | d -= zp_sign * zero_points[idx]; |
1320 | } |
1321 | } |
1322 | |
1323 | float compute_eltwise_fwd(pk_t kind, float src, float alpha, float beta) { |
1324 | // don't compute on nan, propagate it |
1325 | if (std::isnan(src)) return NAN; |
1326 | |
1327 | using namespace dnnl::impl::math; |
1328 | |
1329 | switch (kind) { |
1330 | case pk_t::RELU: return relu_fwd(src, alpha); |
1331 | case pk_t::TANH: return tanh_fwd(src); |
1332 | case pk_t::ELU: return elu_fwd(src, alpha); |
1333 | case pk_t::SQUARE: return square_fwd(src); |
1334 | case pk_t::ABS: return abs_fwd(src); |
1335 | case pk_t::SQRT: return sqrt_fwd(src); |
1336 | case pk_t::LINEAR: return linear_fwd(src, alpha, beta); |
1337 | case pk_t::SRELU: return soft_relu_fwd(src, alpha); |
1338 | case pk_t::MISH: return mish_fwd(src); |
1339 | case pk_t::LOGISTIC: return logistic_fwd(src); |
1340 | case pk_t::EXP: return exp_fwd(src); |
1341 | case pk_t::GELU_TANH: return gelu_tanh_fwd(src); |
1342 | case pk_t::SWISH: return swish_fwd(src, alpha); |
1343 | case pk_t::LOG: return log_fwd(src); |
1344 | case pk_t::CLIP: return clip_fwd(src, alpha, beta); |
1345 | case pk_t::CLIP_V2: return clip_v2_fwd(src, alpha, beta); |
1346 | case pk_t::POW: return pow_fwd(src, alpha, beta); |
1347 | case pk_t::GELU_ERF: return gelu_erf_fwd(src); |
1348 | case pk_t::ROUND: return round_fwd(src); |
1349 | case pk_t::HARDSWISH: return hardswish_fwd(src, alpha, beta); |
1350 | case pk_t::HARDSIGMOID: return hardsigmoid_fwd(src, alpha, beta); |
1351 | case pk_t::RELU_DST: return relu_fwd(src, alpha); |
1352 | case pk_t::TANH_DST: return tanh_fwd(src); |
1353 | case pk_t::ELU_DST: return elu_fwd(src, alpha); |
1354 | case pk_t::SQRT_DST: return sqrt_fwd(src); |
1355 | case pk_t::LOGISTIC_DST: return logistic_fwd(src); |
1356 | case pk_t::EXP_DST: return exp_fwd(src); |
1357 | case pk_t::CLIP_V2_DST: return clip_v2_fwd(src, alpha, beta); |
1358 | |
1359 | default: assert(!"unknown attr::post_ops::kind" ); |
1360 | }; |
1361 | return NAN; |
1362 | } |
1363 | |
1364 | float compute_eltwise_bwd( |
1365 | pk_t kind, float d_dst, float src, float alpha, float beta) { |
1366 | using namespace dnnl::impl::math; |
1367 | |
1368 | switch (kind) { |
1369 | case pk_t::RELU: return relu_bwd(d_dst, src, alpha); |
1370 | case pk_t::TANH: return tanh_bwd(d_dst, src); |
1371 | case pk_t::ELU: return elu_bwd(d_dst, src, alpha); |
1372 | case pk_t::SQUARE: return square_bwd(d_dst, src); |
1373 | case pk_t::ABS: return abs_bwd(d_dst, src); |
1374 | case pk_t::SQRT: return sqrt_bwd(d_dst, src); |
1375 | case pk_t::LINEAR: return linear_bwd(d_dst, src, alpha, beta); |
1376 | case pk_t::SRELU: return soft_relu_bwd(d_dst, src, alpha); |
1377 | case pk_t::MISH: return mish_bwd(d_dst, src); |
1378 | case pk_t::LOGISTIC: return logistic_bwd(d_dst, src); |
1379 | case pk_t::EXP: return exp_bwd(d_dst, src); |
1380 | case pk_t::GELU_TANH: return gelu_tanh_bwd(d_dst, src); |
1381 | case pk_t::SWISH: return swish_bwd(d_dst, src, alpha); |
1382 | case pk_t::LOG: return log_bwd(d_dst, src); |
1383 | case pk_t::CLIP: return clip_bwd(d_dst, src, alpha, beta); |
1384 | case pk_t::CLIP_V2: return clip_v2_bwd(d_dst, src, alpha, beta); |
1385 | case pk_t::POW: return pow_bwd(d_dst, src, alpha, beta); |
1386 | case pk_t::GELU_ERF: return gelu_erf_bwd(d_dst, src); |
1387 | case pk_t::HARDSWISH: return hardswish_bwd(d_dst, src, alpha, beta); |
1388 | case pk_t::HARDSIGMOID: return hardsigmoid_bwd(d_dst, src, alpha, beta); |
1389 | |
1390 | case pk_t::RELU_DST: return relu_bwd_use_dst(d_dst, src, alpha); |
1391 | case pk_t::TANH_DST: return tanh_bwd_use_dst(d_dst, src); |
1392 | case pk_t::ELU_DST: return elu_bwd_use_dst(d_dst, src, alpha); |
1393 | case pk_t::SQRT_DST: return sqrt_bwd_use_dst(d_dst, src); |
1394 | case pk_t::LOGISTIC_DST: return logistic_bwd_use_dst(d_dst, src); |
1395 | case pk_t::EXP_DST: return exp_bwd_use_dst(d_dst, src); |
1396 | case pk_t::CLIP_V2_DST: |
1397 | return clip_v2_bwd_use_dst(d_dst, src, alpha, beta); |
1398 | |
1399 | default: assert(!"unknown attr::post_ops::kind" ); |
1400 | } |
1401 | return NAN; |
1402 | } |
1403 | |
1404 | float compute_binary(pk_t kind, float src0, float src1) { |
1405 | // don't compute on nan, propagate it |
1406 | if (std::isnan(src0) || std::isnan(src1)) return NAN; |
1407 | |
1408 | if (kind == pk_t::ADD) { |
1409 | return src0 + src1; |
1410 | } else if (kind == pk_t::MUL) { |
1411 | return src0 * src1; |
1412 | } else if (kind == pk_t::MAX) { |
1413 | return MAX2(src0, src1); |
1414 | } else if (kind == pk_t::MIN) { |
1415 | return MIN2(src0, src1); |
1416 | } else if (kind == pk_t::DIV) { |
1417 | return src0 / src1; |
1418 | } else if (kind == pk_t::SUB) { |
1419 | return src0 - src1; |
1420 | } else if (kind == pk_t::GE) { |
1421 | return src0 >= src1; |
1422 | } else if (kind == pk_t::GT) { |
1423 | return src0 > src1; |
1424 | } else if (kind == pk_t::LE) { |
1425 | return src0 <= src1; |
1426 | } else if (kind == pk_t::LT) { |
1427 | return src0 < src1; |
1428 | } else if (kind == pk_t::EQ) { |
1429 | return src0 == src1; |
1430 | } else if (kind == pk_t::NE) { |
1431 | return src0 != src1; |
1432 | } else { |
1433 | assert(!"operation not supported!" ); |
1434 | } |
1435 | return NAN; |
1436 | } |
1437 | |
1438 | void maybe_post_ops(const attr_t &attr, float &val, float sum_val, |
1439 | const std::vector<float> &v_po_vals) { |
1440 | using namespace dnnl::impl::math; |
1441 | |
1442 | auto it_po = v_po_vals.begin(); |
1443 | const auto &po = attr.post_ops; |
1444 | for (int idx = 0; idx < po.len(); ++idx) { |
1445 | const auto &e = po.entry[idx]; |
1446 | |
1447 | if (e.is_sum_kind()) { |
1448 | val += e.sum.scale * (sum_val - e.sum.zero_point); |
1449 | } else if (e.is_convolution_kind()) { |
1450 | continue; |
1451 | } else if (e.is_eltwise_kind()) { |
1452 | const auto &a = e.eltwise.alpha; |
1453 | const auto &b = e.eltwise.beta; |
1454 | val = compute_eltwise_fwd(e.kind, val, a, b); |
1455 | } else if (e.is_binary_kind()) { |
1456 | val = compute_binary(e.kind, val, *it_po); |
1457 | it_po++; |
1458 | } else if (e.is_prelu_kind()) { |
1459 | val = val > 0 ? val : val * (*it_po); |
1460 | it_po++; |
1461 | } |
1462 | } |
1463 | } |
1464 | |
1465 | void update_cpu_ref_attrs(attr_t &attr, dnnl_data_type_t new_dt) { |
1466 | auto &po = attr.post_ops; |
1467 | for (int idx = 0; idx < po.len(); ++idx) { |
1468 | auto &e = po.entry[idx]; |
1469 | if (!e.is_binary_kind()) continue; |
1470 | |
1471 | e.binary.src1_dt = new_dt; |
1472 | e.binary.tag = tag::abx; // Hardcoded in setup_binary_po as well. |
1473 | } |
1474 | } |
1475 | |
1476 | #undef BENCHDNN_DNNL_ARG_UNDEF |
1477 | |