1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include <cctype>
19#include <cmath>
20#include <stddef.h>
21#include <stdlib.h>
22#include <string.h>
23#include <string>
24
25#include <algorithm>
26#include <iostream>
27#include <sstream>
28
29#include "oneapi/dnnl/dnnl.h"
30
31#include "src/common/math_utils.hpp"
32
33#include "common.hpp"
34#include "dnn_types.hpp"
35#include "dnnl_common.hpp"
36#include "dnnl_debug.hpp"
37#include "dnnl_memory.hpp"
38#include "utils/parser.hpp"
39
40#define BENCHDNN_DNNL_ARG_UNDEF 0
41
// Shorthand memory-format tag spellings shared by benchdnn drivers.
namespace tag {
const char *x {"x"}; // 1D plain layout
const char *abx {"abx"}; // plain layout, logical dim order (channels-first style)
const char *axb {"axb"}; // plain layout with the 2nd dim last (channels-last style)
const char *any {"any"}; // let the library choose the layout
const char *undef {"undef"}; // tag not set
} // namespace tag
49
50std::ostream &operator<<(std::ostream &s, dir_t dir) {
51#define CASE(x) \
52 if (dir == (x)) return s << STRINGIFY(x)
53 CASE(FWD_B);
54 CASE(FWD_D);
55 CASE(FWD_I);
56 CASE(BWD_D);
57 CASE(BWD_DW);
58 CASE(BWD_W);
59 CASE(BWD_WB);
60#undef CASE
61 SAFE_V(FAIL);
62 return s;
63}
64
65std::ostream &operator<<(std::ostream &s, dnnl_data_type_t dt) {
66 s << dt2str(dt);
67 return s;
68}
69
70std::ostream &operator<<(std::ostream &s, dnnl_engine_kind_t ek) {
71 s << engine_kind2str(ek);
72 return s;
73}
74
75dir_t str2dir(const char *str) {
76#define CASE(x) \
77 if (!strcasecmp(STRINGIFY(x), str)) return x
78 CASE(FWD_D);
79 CASE(FWD_I);
80 CASE(FWD_B);
81 CASE(BWD_D);
82 CASE(BWD_W);
83 CASE(BWD_WB);
84 CASE(BWD_DW);
85#undef CASE
86 assert(!"unknown dir");
87 return DIR_UNDEF;
88}
89
// Translate a benchdnn direction into the library prop_kind. Only the three
// directions that map 1:1 onto a prop_kind are accepted here.
dnnl_prop_kind_t prop2prop_kind(const dir_t dir) {
    if (dir == FWD_D) return dnnl_forward_training;
    if (dir == FWD_I) return dnnl_forward_inference;
    if (dir == BWD_DW) return dnnl_backward;
    assert(!"unknown dir");
    return dnnl_prop_kind_undef;
}

// Inverse of prop2prop_kind(): render a prop_kind as a direction token.
const char *prop2str(dnnl_prop_kind_t prop) {
    if (prop == dnnl_forward_training) return "FWD_D";
    if (prop == dnnl_forward_inference) return "FWD_I";
    if (prop == dnnl_backward) return "BWD_DW";
    assert(!"unknown prop_kind");
    return "unknown prop_kind";
}
105
// Render a data_kind_t tensor role as its upper-case debug name.
// The fallback string is returned after asserting in debug builds.
const char *data_kind2str(data_kind_t kind) {
    switch (kind) {
        case SRC: return "SRC";
        case SRC_1: return "SRC_ADD";
        case WEI: return "WEI";
        case BIA: return "BIA";
        case DST: return "DST";
        case ACC: return "ACC";
        case MEAN: return "MEAN";
        case VAR: return "VAR";
        case SC: return "SC";
        case SH: return "SH";
        case DST_ITER: return "DST_ITER";
        case DST_ITER_C: return "DST_ITER_C";
        case AUGRU_ATTENTION: return "AUGRU_ATTENTION";
        case SRC_ITER: return "SRC_ITER";
        case SRC_ITER_C: return "SRC_ITER_C";
        case WEI_ITER: return "WEI_ITER";
        case WEI_PEEPHOLE: return "WEI_PEEPHOLE";
        case WEI_PROJECTION: return "WEI_PROJECTION";
        default: assert(!"incorrect data kind");
    }
    return "incorrect data kind";
}
130
// Execution-argument spellings accepted on the command line. The first name
// in each list is the canonical one used by arg2str() when printing.
static const std::map<int, std::vector<const char *>> supported_args {
        {DNNL_ARG_SRC, {"src", "src0"}},
        {DNNL_ARG_SRC_1, {"src1"}},
        {DNNL_ARG_WEIGHTS, {"wei"}},
        {DNNL_ARG_DST, {"dst"}},
};
137
138static int str2arg(const std::string &str) {
139 for (const auto &arg : supported_args)
140 for (const auto &s : arg.second)
141 if (str.compare(s) == 0) return arg.first;
142 // multiple srcs
143 if (str.compare(0, 3, "msrc")) {
144 const auto &str_index = str.substr(4);
145 const auto index = stoul(str_index);
146 return DNNL_ARG_MULTIPLE_SRC + index;
147 }
148 return BENCHDNN_DNNL_ARG_UNDEF;
149}
150
151static std::string arg2str(int arg) {
152 if (supported_args.find(arg) != supported_args.end())
153 return std::string(supported_args.at(arg)[0]);
154 if (arg & DNNL_ARG_MULTIPLE_SRC) {
155 std::string msrc("msrc");
156 const int index = arg - DNNL_ARG_MULTIPLE_SRC;
157 return msrc + std::to_string(index);
158 }
159 assert(!"unknown argument");
160 return "unknown argument";
161}
162
163policy_t attr_t::str2policy(const std::string &str) {
164 std::string s(str);
165 // s.compare is lexicographical, case matters
166 std::transform(s.begin(), s.end(), s.begin(), ::toupper);
167#define CASE(_plc) \
168 if (s.compare(STRINGIFY(_plc)) == 0) return _plc
169 CASE(COMMON);
170 CASE(PER_OC);
171 CASE(PER_DIM_0);
172 CASE(PER_DIM_1);
173 CASE(PER_DIM_01);
174 CASE(PER_DIM_2);
175 CASE(PER_DIM_023);
176 CASE(PER_DIM_23);
177 CASE(PER_DIM_03);
178 CASE(PER_DIM_3);
179 CASE(PER_TENSOR);
180#undef CASE
181 assert(!"unknown attr_t::policy_t policy");
182 return POLICY_TOTAL;
183}
184
185const char *attr_t::policy2str(policy_t policy) {
186 if (policy == COMMON) return "common";
187 if (policy == PER_OC) return "per_oc";
188 if (policy == PER_DIM_0) return "per_dim_0";
189 if (policy == PER_DIM_1) return "per_dim_1";
190 if (policy == PER_DIM_01) return "per_dim_01";
191 if (policy == PER_DIM_2) return "per_dim_2";
192 if (policy == PER_DIM_023) return "per_dim_023";
193 if (policy == PER_DIM_23) return "per_dim_23";
194 if (policy == PER_DIM_03) return "per_dim_03";
195 if (policy == PER_DIM_3) return "per_dim_3";
196 if (policy == PER_TENSOR) return "per_tensor";
197 assert(!"unknown attr_t::policy_t policy");
198 return "unknown attr_t::policy_t policy";
199}
200
// Return the dimension bit-mask implied by a scales/zero-points policy.
// For PER_OC the output-channel dimension is dim 0 for weights and dim 1
// otherwise — hence the deliberate fallthrough into PER_DIM_1 below.
int attr_t::get_default_mask(policy_t policy, int arg) {
    switch (policy) {
        case PER_DIM_0: return (1 << 0);
        case PER_OC:
            if (arg == DNNL_ARG_WEIGHTS) return get_default_mask(PER_DIM_0);
            // intentional fallthrough: non-weights PER_OC means PER_DIM_1
        case PER_DIM_1: return (1 << 1);
        case PER_DIM_01: return (1 << 0) + (1 << 1);
        case PER_DIM_2: return (1 << 2);
        case PER_DIM_023: return (1 << 0) + (1 << 2) + (1 << 3);
        case PER_DIM_23: return (1 << 2) + (1 << 3);
        case PER_DIM_03: return (1 << 0) + (1 << 3);
        case PER_DIM_3: return (1 << 3);
        case PER_TENSOR: return (1 << DNNL_MAX_NDIMS) - 1;
        case COMMON: return 0;
        default: SAFE(FAIL, CRIT); return 0;
    }
}
219
// This function takes input string, extracts float value and runtime, if
// present, from the string. Updates @value and @runtime with extracted values.
// Accepted forms are "VAL" and "VAL*" (a trailing '*' marks a runtime value);
// anything else returns FAIL. An unparsable number terminates the run.
int parse_value_and_runtime(float &value, bool &runtime, const std::string &s) {
    // process value
    size_t scale_pos = 0;
    try {
        value = std::stof(s, &scale_pos);
    } catch (const std::invalid_argument &) {
        BENCHDNN_PRINT(0, "%s\n%s \'%s\'; %s\n",
                "Error: output scale or zero point input value is invalid.",
                "Given input:", s.c_str(),
                "Expected input: \'VAL[*]\'. See help for proper syntax.");
        exit(1);
    }
    runtime = false;
    // More than one character after the number: malformed.
    if (scale_pos + 1 < s.size()) return FAIL;
    // Number consumed the whole string: plain value.
    if (scale_pos == s.size()) return OK;
    // Exactly one trailing character; it must be the runtime marker.
    if (s.back() != '*') return FAIL;
    runtime = true;
    return OK;
}
241
// Parse a scale specification of the form "POLICY[:VALUE[*]]".
// Resets *this first so an empty string yields the default scale.
int attr_t::scale_t::from_str(const std::string &s) {
    *this = scale_t();
    if (s.empty()) return OK;

    size_t start_pos = 0;
    // process policy
    this->policy = str2policy(parser::get_substr(s, start_pos, ':'));
    if (this->policy == POLICY_TOTAL) return FAIL;
    // No ':' found: policy-only form is valid.
    if (start_pos == std::string::npos) return OK;
    if (start_pos >= s.size()) return FAIL; // to catch dangling ':'

    SAFE(parse_value_and_runtime(this->scale, this->runtime,
                 parser::get_substr(s, start_pos, ':')),
            WARN);
    // Negative scales are not allowed.
    if (this->scale < 0) return FAIL;
    return OK;
}
259
// Parse a zero-points specification: "ARG:POLICY:VALUE[*]" entries joined
// with '+'. Resets *this first so an empty string yields defaults.
int attr_t::zero_points_t::from_str(const std::string &s) {
    *this = zero_points_t();
    if (s.empty()) return OK;

    size_t start_pos = 0;
    while (start_pos != std::string::npos) {
        // One '+'-separated entry at a time.
        auto subs = parser::get_substr(s, start_pos, '+');
        size_t subs_pos = 0;

        auto arg = str2arg(parser::get_substr(subs, subs_pos, ':'));
        // Each entry must have all three fields; npos / out-of-range
        // positions indicate a missing or dangling ':'.
        if (arg == BENCHDNN_DNNL_ARG_UNDEF || subs_pos == std::string::npos
                || subs_pos >= subs.size())
            return FAIL;

        auto policy = str2policy(parser::get_substr(subs, subs_pos, ':'));
        if (policy == POLICY_TOTAL || subs_pos == std::string::npos
                || subs_pos >= subs.size())
            return FAIL;

        float zp = 0;
        bool runtime = false;
        SAFE(parse_value_and_runtime(
                     zp, runtime, parser::get_substr(subs, subs_pos, '\0')),
                WARN);
        // Zero points are integral; the float parse is reused for '*'.
        set(arg, policy, static_cast<int>(zp), runtime);
    }
    return OK;
}
288
289int attr_t::arg_scales_t::from_str(const std::string &s) {
290 *this = arg_scales_t();
291 if (s.empty()) return OK;
292
293 size_t start_pos = 0;
294 while (start_pos != std::string::npos) {
295 auto subs = parser::get_substr(s, start_pos, '+');
296 // Special handling for really big float values
297 if (subs.back() == 'e') {
298 auto subs_add = parser::get_substr(s, start_pos, '+');
299 subs += subs_add;
300 }
301 size_t subs_pos = 0;
302
303 auto arg = str2arg(parser::get_substr(subs, subs_pos, ':'));
304 if (arg == BENCHDNN_DNNL_ARG_UNDEF || subs_pos == std::string::npos
305 || subs_pos >= s.size())
306 return FAIL;
307
308 scale_t arg_scale;
309 SAFE(arg_scale.from_str(parser::get_substr(subs, subs_pos, '\0')),
310 WARN);
311 set(arg, arg_scale);
312 }
313 return OK;
314}
315
// Short alias for the post-op kind enum used throughout this file.
using pk_t = attr_t::post_ops_t::kind_t;

// One row of the post-op lookup table: the benchdnn kind, the CLI spellings
// that select it, and the matching library algorithm kind.
struct po_table_entry_t {
    pk_t kind;
    std::vector<std::string> kind_names;
    dnnl_alg_kind_t dnnl_kind;
};
323
// Post-op kind lookup table. ELTWISE_START/END and BINARY_START/END act as
// range guards for the is_*_kind() predicates; the final KIND_TOTAL entry is
// the "not found" result returned by the lookup helpers below.
static po_table_entry_t kind_table[] = {
        // sum
        {pk_t::SUM, {"sum"}, dnnl_alg_kind_undef},
        // depthwise convolution
        {pk_t::DW, {"dw"}, dnnl_convolution_auto},
        {pk_t::DW_K3S1P1, {"dw_k3s1p1"}, dnnl_convolution_auto},
        {pk_t::DW_K3S2P1, {"dw_k3s2p1"}, dnnl_convolution_auto},
        // eltwise
        {pk_t::ELTWISE_START, {"eltwise_undef"}, dnnl_alg_kind_undef},
        {pk_t::ABS, {"abs", "eltwise_abs"}, dnnl_eltwise_abs},
        {pk_t::CLIP, {"clip", "eltwise_clip"}, dnnl_eltwise_clip},
        {pk_t::CLIP_V2, {"clip_v2", "eltwise_clip_v2"}, dnnl_eltwise_clip_v2},
        {pk_t::CLIP_V2_DST, {"clip_v2_dst", "eltwise_clip_v2_use_dst_for_bwd"},
                dnnl_eltwise_clip_v2_use_dst_for_bwd},
        {pk_t::ELU, {"elu", "eltwise_elu"}, dnnl_eltwise_elu},
        {pk_t::ELU_DST, {"elu_dst", "eltwise_elu_use_dst_for_bwd"},
                dnnl_eltwise_elu_use_dst_for_bwd},
        {pk_t::EXP, {"exp", "eltwise_exp"}, dnnl_eltwise_exp},
        {pk_t::EXP_DST, {"exp_dst", "eltwise_exp_use_dst_for_bwd"},
                dnnl_eltwise_exp_use_dst_for_bwd},
        {pk_t::GELU_ERF, {"gelu_erf", "eltwise_gelu_erf"},
                dnnl_eltwise_gelu_erf},
        {pk_t::GELU_TANH, {"gelu_tanh", "eltwise_gelu_tanh"},
                dnnl_eltwise_gelu_tanh},
        {pk_t::HARDSIGMOID, {"hardsigmoid", "eltwise_hardsigmoid"},
                dnnl_eltwise_hardsigmoid},
        {pk_t::HARDSWISH, {"hardswish", "eltwise_hardswish"},
                dnnl_eltwise_hardswish},
        {pk_t::LINEAR, {"linear", "eltwise_linear"}, dnnl_eltwise_linear},
        {pk_t::LOG, {"log", "eltwise_log"}, dnnl_eltwise_log},
        {pk_t::LOGISTIC, {"logistic", "eltwise_logistic"},
                dnnl_eltwise_logistic},
        {pk_t::LOGISTIC_DST,
                {"logistic_dst", "eltwise_logistic_use_dst_for_bwd"},
                dnnl_eltwise_logistic_use_dst_for_bwd},
        {pk_t::MISH, {"mish", "eltwise_mish"}, dnnl_eltwise_mish},
        {pk_t::POW, {"pow", "eltwise_pow"}, dnnl_eltwise_pow},
        {pk_t::RELU, {"relu", "eltwise_relu"}, dnnl_eltwise_relu},
        {pk_t::RELU_DST, {"relu_dst", "eltwise_relu_use_dst_for_bwd"},
                dnnl_eltwise_relu_use_dst_for_bwd},
        {pk_t::ROUND, {"round", "eltwise_round"}, dnnl_eltwise_round},
        {pk_t::SQRT, {"sqrt", "eltwise_sqrt"}, dnnl_eltwise_sqrt},
        {pk_t::SQRT_DST, {"sqrt_dst", "eltwise_sqrt_use_dst_for_bwd"},
                dnnl_eltwise_sqrt_use_dst_for_bwd},
        {pk_t::SQUARE, {"square", "eltwise_square"}, dnnl_eltwise_square},
        {pk_t::SRELU, {"soft_relu", "eltwise_soft_relu", "srelu"},
                dnnl_eltwise_soft_relu},
        {pk_t::SWISH, {"swish", "eltwise_swish"}, dnnl_eltwise_swish},
        {pk_t::TANH, {"tanh", "eltwise_tanh"}, dnnl_eltwise_tanh},
        {pk_t::TANH_DST, {"tanh_dst", "eltwise_tanh_use_dst_for_bwd"},
                dnnl_eltwise_tanh_use_dst_for_bwd},
        {pk_t::ELTWISE_END, {"eltwise_undef"}, dnnl_alg_kind_undef},
        // binary
        {pk_t::BINARY_START, {"binary_undef"}, dnnl_alg_kind_undef},
        {pk_t::ADD, {"add", "binary_add"}, dnnl_binary_add},
        {pk_t::DIV, {"div", "binary_div"}, dnnl_binary_div},
        {pk_t::EQ, {"eq", "binary_eq"}, dnnl_binary_eq},
        {pk_t::GE, {"ge", "binary_ge"}, dnnl_binary_ge},
        {pk_t::GT, {"gt", "binary_gt"}, dnnl_binary_gt},
        {pk_t::LE, {"le", "binary_le"}, dnnl_binary_le},
        {pk_t::LT, {"lt", "binary_lt"}, dnnl_binary_lt},
        {pk_t::MAX, {"max", "binary_max"}, dnnl_binary_max},
        {pk_t::MIN, {"min", "binary_min"}, dnnl_binary_min},
        {pk_t::MUL, {"mul", "binary_mul"}, dnnl_binary_mul},
        {pk_t::NE, {"ne", "binary_ne"}, dnnl_binary_ne},
        {pk_t::SUB, {"sub", "binary_sub"}, dnnl_binary_sub},
        {pk_t::BINARY_END, {"binary_undef"}, dnnl_alg_kind_undef},
        // prelu
        {pk_t::PRELU, {"prelu"}, dnnl_alg_kind_undef},
        // guard entry
        {pk_t::KIND_TOTAL, {"kind_undef"}, dnnl_alg_kind_undef}};
395
// Parse a post-op kind name (case-insensitively) against every accepted
// spelling in kind_table. On failure prints an error and returns the guard
// entry (KIND_TOTAL), which callers treat as "not recognized".
pk_t attr_t::post_ops_t::str2kind(const std::string &str) {
    std::string s(str);
    // string::operator== is lexicographical, case matters
    std::transform(s.begin(), s.end(), s.begin(), ::tolower);
    for (const auto &e : kind_table) {
        for (const auto &name : e.kind_names) {
            if (s == name) return e.kind;
        }
    }
    BENCHDNN_PRINT(0, "%s\'%s\' %s\n", "Error: ", str.c_str(),
            "kind of post operation entry was not recognized.");

    // The guard entry is the last row of the table.
    const auto table_size = sizeof(kind_table) / sizeof(*kind_table);
    return kind_table[table_size - 1].kind;
}
411
412const char *attr_t::post_ops_t::kind2str(pk_t kind) {
413 for (const auto &e : kind_table) {
414 if (e.kind == kind) return e.kind_names[0].c_str();
415 }
416 assert(!"unknown attr::post_ops::kind");
417 const auto table_size = sizeof(kind_table) / sizeof(*kind_table);
418 return kind_table[table_size - 1].kind_names[0].c_str();
419}
420
421dnnl_alg_kind_t attr_t::post_ops_t::kind2dnnl_kind(pk_t kind) {
422 for (const auto &e : kind_table) {
423 if (e.kind == kind) return e.dnnl_kind;
424 }
425 assert(!"unknown attr::post_ops::kind");
426 const auto table_size = sizeof(kind_table) / sizeof(*kind_table);
427 return kind_table[table_size - 1].dnnl_kind;
428}
429
// Collect (exec-arg, mask) pairs for every post-op that carries an extra
// input tensor (binary src1 and prelu weights). Other post-op kinds are
// skipped. The exec arg is tagged with the post-op index via
// DNNL_ARG_ATTR_MULTIPLE_POST_OP.
std::vector<std::pair<int, int>> attr_t::post_ops_t::get_po_masks() const {
    std::vector<std::pair<int, int>> v_masks;
    for (int idx = 0; idx < len(); ++idx) {
        const auto &e = this->entry[idx];
        policy_t policy = policy_t::COMMON;
        int arg = BENCHDNN_DNNL_ARG_UNDEF;
        if (e.is_binary_kind()) {
            policy = e.binary.policy;
            arg = DNNL_ARG_SRC_1;
        } else if (e.is_prelu_kind()) {
            policy = e.prelu.policy;
            arg = DNNL_ARG_WEIGHTS;
        } else
            continue;

        const auto mask = attr_t::get_default_mask(policy);
        v_masks.emplace_back(std::make_pair(
                DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | arg, mask));
    }
    return v_masks;
}
451
// Parse a post-ops chain: '+'-separated entries, each "KIND[:ARG[:ARG...]]".
// Per-kind optional arguments:
//   sum:      scale[:zero_point[:dt]]
//   dw*:      [kXsYpZ (DW only)]:dst_dt[:wei_scale[:dst_scale]]
//   eltwise:  alpha[:beta]
//   binary:   src1_dt[:policy[:tag]]
//   prelu:    [policy]
// Resets *this first; an empty string yields no post-ops. The repeated
// "npos -> next entry" / "pos >= size() -> dangling ':'" checks keep
// trailing separators from being silently accepted.
int attr_t::post_ops_t::from_str(const std::string &s) {
    *this = post_ops_t();
    if (s.empty()) return OK;

    size_t start_pos = 0;
    while (start_pos != std::string::npos) {
        auto subs = parser::get_substr(s, start_pos, '+');
        size_t subs_pos = 0;

        auto kind = str2kind(parser::get_substr(subs, subs_pos, ':'));
        if (kind == KIND_TOTAL) return FAIL;

        entry.emplace_back(kind);
        // Kind-only entry: all optional arguments take their defaults.
        if (subs_pos == std::string::npos) continue;
        if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':'

        auto &e = entry.back();
        if (e.is_sum_kind()) {
            e.sum.scale = std::stof(parser::get_substr(subs, subs_pos, ':'));
            if (subs_pos == std::string::npos) continue;
            if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':'

            auto zp_str = parser::get_substr(subs, subs_pos, ':');
            e.sum.zero_point = std::stoi(zp_str);
            if (subs_pos == std::string::npos) continue;
            if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':'
            // Round-trip check rejects trailing junk after the integer.
            if (std::to_string(e.sum.zero_point) != zp_str) return FAIL;

            e.sum.dt = str2dt(parser::get_substr(subs, subs_pos, ':').c_str());
            // sum dt, if specified, should be defined
            if (e.sum.dt == dnnl_data_type_undef) return FAIL;
        } else if (e.is_convolution_kind()) {
            if (kind == DW) {
                // `DW` has input of `dw:kXsYpZ`, while rest have `dw_k3sXp1`.
                const auto str_dw_params
                        = parser::get_substr(subs, subs_pos, ':');
                size_t pos = 0, idx = 0;

                // `idx` is how many chars stoi consumed; advancing `pos` by
                // it walks from one letter marker to the next.
                pos += idx;
                if (str_dw_params[pos] != 'k') return FAIL;
                e.convolution.kernel = std::stoi(&str_dw_params[++pos], &idx);

                pos += idx;
                if (str_dw_params[pos] != 's') return FAIL;
                e.convolution.stride = std::stoi(&str_dw_params[++pos], &idx);

                pos += idx;
                if (str_dw_params[pos] != 'p') return FAIL;
                e.convolution.padding = std::stoi(&str_dw_params[++pos]);

                if (subs_pos == std::string::npos) continue;
            }

            e.convolution.dst_dt
                    = str2dt(parser::get_substr(subs, subs_pos, ':').c_str());
            if (e.convolution.dst_dt == dnnl_data_type_undef) return FAIL;
            if (subs_pos == std::string::npos) continue;
            if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':'

            // Remainder holds "wei_scale[:dst_scale]"; wei parse ignores the
            // extra fields, so dst_scale is located after the 2nd ':'.
            auto scale_str = parser::get_substr(subs, subs_pos, '+');
            SAFE(e.convolution.wei_scale.from_str(scale_str), WARN);
            size_t dst_scale_pos = 0;
            for (int i = 0; i < 2; ++i)
                dst_scale_pos = scale_str.find(":", dst_scale_pos + 1);
            if (dst_scale_pos != std::string::npos) {
                auto dst_scale_str = scale_str.substr(dst_scale_pos + 1);
                SAFE(e.convolution.dst_scale.from_str(dst_scale_str), WARN);
            }
        } else if (e.is_eltwise_kind()) {
            e.eltwise.alpha
                    = std::stof(parser::get_substr(subs, subs_pos, ':'));
            if (subs_pos == std::string::npos) continue;
            if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':'

            e.eltwise.beta = std::stof(parser::get_substr(subs, subs_pos, ':'));
            if (subs_pos == std::string::npos) continue;
            if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':'
        } else if (e.is_binary_kind()) {
            e.binary.src1_dt
                    = str2dt(parser::get_substr(subs, subs_pos, ':').c_str());
            if (e.binary.src1_dt == dnnl_data_type_undef) return FAIL;
            if (subs_pos == std::string::npos) continue;
            if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':'

            e.binary.policy
                    = str2policy(parser::get_substr(subs, subs_pos, ':'));
            if (e.binary.policy == POLICY_TOTAL) return FAIL;
            if (subs_pos == std::string::npos) continue;
            if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':'

            e.binary.tag = parser::get_substr(subs, subs_pos, ':');
            SAFE(check_tag(e.binary.tag), WARN);
        } else if (e.is_prelu_kind()) {
            e.prelu.policy
                    = str2policy(parser::get_substr(subs, subs_pos, ':'));
            if (e.prelu.policy == POLICY_TOTAL) return FAIL;
            if (subs_pos == std::string::npos) continue;
            if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':'
        }
        if (subs_pos == std::string::npos) continue;
        if (subs_pos >= subs.size()) return FAIL; // to catch dangling ':'
    }
    return OK;
}
556
557bool attr_t::is_def(bool skip_fpmath) const {
558 return scales.is_def() && zero_points.is_def() && post_ops.is_def()
559 && scratchpad_mode == dnnl_scratchpad_mode_library
560 && IMPLICATION(
561 !skip_fpmath, fpmath_mode == dnnl_fpmath_mode_strict);
562}
563
564int attr_t::post_ops_t::find(pk_t kind, int start, int stop) const {
565 if (stop == -1) stop = len();
566 stop = MIN2(stop, len());
567 for (int idx = start; idx < stop; ++idx)
568 if (entry[idx].kind == kind) return idx;
569 return -1;
570}
571
// Kind classification predicates. Eltwise and binary kinds are identified by
// position between the *_START/*_END guard values of the kind enum.
bool attr_t::post_ops_t::entry_t::is_sum_kind() const {
    return kind == SUM;
}
// Any of the fused depthwise-convolution variants.
bool attr_t::post_ops_t::entry_t::is_convolution_kind() const {
    return kind == DW || kind == DW_K3S1P1 || kind == DW_K3S2P1;
}
bool attr_t::post_ops_t::entry_t::is_eltwise_kind() const {
    return kind > ELTWISE_START && kind < ELTWISE_END;
}
bool attr_t::post_ops_t::entry_t::is_binary_kind() const {
    return kind > pk_t::BINARY_START && kind < pk_t::BINARY_END;
}
bool attr_t::post_ops_t::entry_t::is_prelu_kind() const {
    return kind == PRELU;
}
587
588int attr_t::post_ops_t::convolution_index() const {
589 for (int i = 0; i < len(); ++i) {
590 if (entry[i].is_convolution_kind()) return i;
591 }
592 return -1;
593}
594
595int attr_t::post_ops_t::eltwise_index() const {
596 for (int i = 0; i < len(); ++i) {
597 if (entry[i].is_eltwise_kind()) return i;
598 }
599 return -1;
600}
601
602int attr_t::post_ops_t::binary_index() const {
603 for (int i = 0; i < len(); ++i) {
604 if (entry[i].is_binary_kind()) return i;
605 }
606 return -1;
607}
608
609int attr_t::post_ops_t::prelu_index() const {
610 for (int i = 0; i < len(); ++i) {
611 if (entry[i].is_prelu_kind()) return i;
612 }
613 return -1;
614}
615
616std::ostream &operator<<(std::ostream &s, const policy_t &policy) {
617 s << attr_t::policy2str(policy);
618 return s;
619}
620
621std::ostream &operator<<(std::ostream &s, const attr_t::scale_t &scale) {
622 s << scale.policy << ":" << scale.scale;
623 if (scale.runtime) s << '*';
624 return s;
625}
626
627std::ostream &operator<<(
628 std::ostream &s, const attr_t::zero_points_t &zero_points) {
629 const char *delim = "";
630 for (const auto &point : zero_points.points) {
631 s << delim;
632 s << arg2str(point.first) << ":" << point.second.policy << ":"
633 << point.second.value;
634 if (point.second.runtime) s << '*';
635 delim = "+";
636 }
637
638 return s;
639}
640
641std::ostream &operator<<(std::ostream &s, const attr_t::arg_scales_t &scales) {
642 const char *delim = "";
643 for (const auto &v : scales.scales) {
644 if (!v.second.is_def()) {
645 s << delim;
646 s << arg2str(v.first) << ":" << v.second;
647 delim = "+";
648 }
649 }
650 return s;
651}
652
653std::ostream &operator<<(std::ostream &s, const attr_t::post_ops_t::kind_t &k) {
654 s << attr_t::post_ops_t::kind2str(k);
655 return s;
656}
657
// Stream a post-ops chain in the same syntax from_str() parses. Optional
// fields are printed only while some later field is non-default, so the
// shortest equivalent spelling is produced.
std::ostream &operator<<(std::ostream &s, const attr_t::post_ops_t &post_ops) {
    for (int idx = 0; idx < post_ops.len(); ++idx) {
        if (idx > 0) s << "+";

        const auto &e = post_ops.entry[idx];
        s << e.kind;

        if (e.is_sum_kind()) {
            // Each field is printed if it, or any field after it, differs
            // from its default.
            if (e.sum.scale != 1.0f || e.sum.zero_point != 0
                    || e.sum.dt != dnnl_data_type_undef)
                s << ":" << e.sum.scale;
            if (e.sum.zero_point != 0 || e.sum.dt != dnnl_data_type_undef)
                s << ":" << e.sum.zero_point;
            if (e.sum.dt != dnnl_data_type_undef) s << ":" << e.sum.dt;
        } else if (e.is_convolution_kind()) {
            // Only the generic DW kind carries explicit k/s/p parameters.
            if (e.kind == pk_t::DW) {
                s << ":k" << e.convolution.kernel << "s" << e.convolution.stride
                  << "p" << e.convolution.padding;
            }
            const auto &c_ws = e.convolution.wei_scale;
            const auto &c_ds = e.convolution.dst_scale;
            if (e.convolution.dst_dt != dnnl_f32 || !c_ws.is_def()
                    || !c_ds.is_def())
                s << ":" << e.convolution.dst_dt;
            if (!c_ws.is_def() || !c_ds.is_def()) s << ":" << c_ws;
            if (!c_ds.is_def()) s << ":" << c_ds;
        } else if (e.is_eltwise_kind()) {
            if (e.eltwise.beta != 0.f)
                s << ":" << e.eltwise.alpha << ":" << e.eltwise.beta;
            else if (e.eltwise.alpha != 0.f)
                s << ":" << e.eltwise.alpha;
        } else if (e.is_binary_kind()) {
            s << ":" << e.binary.src1_dt;
            if (e.binary.policy != policy_t::COMMON) {
                s << ":" << e.binary.policy;
                // A mask touching dims >= 2 implies a user-given tag.
                if (attr_t::get_default_mask(e.binary.policy) >= 4)
                    s << ":" << e.binary.tag;
            }
        } else if (e.is_prelu_kind()) {
            if (e.prelu.policy != policy_t::COMMON) {
                s << ":" << e.prelu.policy;
            }
        } else {
            assert(!"unknown kind");
            s << "unknown_kind";
        }
    }

    return s;
}
708
709std::ostream &operator<<(std::ostream &s, dnnl_scratchpad_mode_t sm) {
710 s << scratchpad_mode2str(sm);
711 return s;
712}
713
714std::ostream &operator<<(std::ostream &s, dnnl_fpmath_mode_t fm) {
715 s << fpmath_mode2str(fm);
716 return s;
717}
718
719std::ostream &operator<<(std::ostream &s, const attr_t &attr) {
720 if (!attr.is_def()) {
721 if (!attr.scales.is_def()) s << "--attr-scales=" << attr.scales << " ";
722 if (!attr.zero_points.is_def())
723 s << "--attr-zero-points=" << attr.zero_points << " ";
724 if (!attr.post_ops.is_def())
725 s << "--attr-post-ops=" << attr.post_ops << " ";
726 if (attr.scratchpad_mode != dnnl_scratchpad_mode_library)
727 s << "--attr-scratchpad=" << attr.scratchpad_mode << " ";
728 if (attr.fpmath_mode != dnnl_fpmath_mode_strict)
729 s << "--attr-fpmath=" << attr.fpmath_mode << " ";
730 }
731 return s;
732}
733
// Stream the bench mode as its single-letter flags. "R" is implied by C/P
// and printed only when RUN is requested on its own.
std::ostream &operator<<(std::ostream &s, bench_mode_t mode) {
    if (is_bench_mode(RUN) && !(is_bench_mode(CORR) || is_bench_mode(PERF)))
        s << "R";
    if (is_bench_mode(CORR)) s << "C";
    if (is_bench_mode(PERF)) s << "P";
    if (is_bench_mode(LIST)) s << "L";
    if (is_bench_mode(PROF)) s << "O";
    return s;
}
743
744std::ostream &operator<<(std::ostream &s, memory_kind_ext_t memory_kind) {
745 switch (memory_kind) {
746 case memory_kind_ext_t::usm: s << "usm"; break;
747 case memory_kind_ext_t::buffer: s << "buffer"; break;
748 case memory_kind_ext_t::usm_device: s << "usm_device"; break;
749 case memory_kind_ext_t::usm_shared: s << "usm_shared"; break;
750 default: assert(!"unexpected"); break;
751 }
752 return s;
753}
754
// Print the global benchdnn options in reproducer command-line form.
// Options at their default value are omitted unless --canonical is set.
std::ostream &dump_global_params(std::ostream &s) {
    s << "--" << driver_name << " ";
    if (canonical) s << "--canonical=" << bool2str(canonical) << " ";
    if (canonical || engine_tgt_kind != dnnl_cpu) {
        s << "--engine=" << engine_tgt_kind;
        // Engine index is appended only when a non-default device is used.
        if (engine_index != 0) s << ":" << engine_index;
        s << " ";
    }
    if (canonical || fast_ref_gpu != true)
        s << "--fast-ref-gpu=" << bool2str(fast_ref_gpu) << " ";
    if (!skip_impl.empty()) s << "--skip-impl=" << skip_impl << " ";
    if (canonical || mem_check != true)
        s << "--mem-check=" << bool2str(mem_check) << " ";
    if (canonical || allow_enum_tags_only != true)
        s << "--allow-enum-tags-only=" << bool2str(allow_enum_tags_only) << " ";
    if (canonical || hints.get() != isa_hints_t::none)
        s << "--cpu-isa-hints=" << isa_hints_t::hints2str(hints) << " ";
    if (canonical || bench_mode != CORR) s << "--mode=" << bench_mode << " ";
    if (canonical || attr_same_pd_check != false)
        s << "--attr-same-pd-check=" << bool2str(attr_same_pd_check) << " ";
// --memory-kind exists only for SYCL/OpenCL builds.
#if defined(DNNL_WITH_SYCL) || DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
    if (canonical || memory_kind != default_memory_kind)
        s << "--memory-kind=" << memory_kind << " ";
#endif

    return s;
}
782
783dnnl_engine_kind_t str2engine_kind(const char *str) {
784 const char *param = "cpu";
785 if (!strncasecmp(param, str, strlen(param))) return dnnl_cpu;
786
787 param = "gpu";
788 if (!strncasecmp(param, str, strlen(param))) return dnnl_gpu;
789
790 assert(!"not expected");
791 return dnnl_cpu;
792}
793
794dnnl_scratchpad_mode_t str2scratchpad_mode(const char *str) {
795 const char *param = "library";
796 if (!strncasecmp(param, str, strlen(param)))
797 return dnnl_scratchpad_mode_library;
798
799 param = "user";
800 if (!strncasecmp(param, str, strlen(param)))
801 return dnnl_scratchpad_mode_user;
802
803 assert(!"not expected");
804 return dnnl_scratchpad_mode_library;
805}
806
807dnnl_fpmath_mode_t str2fpmath_mode(const char *str) {
808 if (std::strcmp(str, "") == 0) {
809 dnnl_fpmath_mode_t ret;
810 dnnl_get_default_fpmath_mode(&ret);
811 return ret;
812 }
813
814#define CASE(fpm) \
815 param = #fpm; \
816 if (!strncasecmp(param, str, strlen(param))) return dnnl_fpmath_mode_##fpm;
817
818 const char *param;
819
820 CASE(strict);
821 CASE(bf16);
822 CASE(f16);
823 CASE(tf32);
824 CASE(any);
825
826 assert(!"not expected");
827 return dnnl_fpmath_mode_strict;
828
829#undef CASE
830}
831
// Register user-provided scale values for `arg`, carrying over the runtime
// flag from the attribute definition.
void attr_args_t::prepare_scales(const attr_t &attr, int arg, const void *vals,
        int64_t count, int mask) {
    insert(arg, vals, count, mask, attr.scales.get(arg).runtime);
}

// Properties of a post-op's extra ("right-hand side") input tensor: its data
// type, broadcast policy, format tag, and the exec-arg it binds to.
struct post_ops_rhs_tensor_entry_t {
    dnnl_data_type_t dt;
    policy_t policy;
    std::string tag;
    int arg_attr_mask;
};
843
844namespace {
845
846post_ops_rhs_tensor_entry_t get_po_rhs_tensor_entry(
847 const attr_t::post_ops_t::entry_t &entry) {
848 if (entry.is_prelu_kind()) {
849 const auto &prelu = entry.prelu;
850 return {dnnl_f32, prelu.policy, tag::axb, DNNL_ARG_WEIGHTS};
851 } else if (entry.is_binary_kind()) {
852 const auto &binary = entry.binary;
853 return {binary.src1_dt, binary.policy, binary.tag, DNNL_ARG_SRC_1};
854 }
855
856 return post_ops_rhs_tensor_entry_t {};
857}
858
859} // namespace
860
// Create memory descriptors for every post-op that needs an extra input
// tensor (binary src1, prelu weights), deducing each tensor's dims from the
// problem dims and the policy's broadcast mask.
int attr_args_t::prepare_post_ops_mds(
        const attr_t &attr, int ndims, const dnnl_dims_t dims) {
    const auto &po = attr.post_ops;
    // iterate over all post ops and prepare md for each binary
    for (int idx = 0; idx < po.len(); ++idx) {
        const auto &e = po.entry[idx];
        if (e.is_binary_kind() || e.is_prelu_kind()) {

            const auto po_rhs_tensor_entry = get_po_rhs_tensor_entry(e);
            const int mask
                    = attr_t::get_default_mask(po_rhs_tensor_entry.policy);

            // deduce binary, prelu dims based on input policy
            // (dims not covered by the mask are broadcast, i.e. size 1)
            dnnl_dims_t rhs_tensor_dims = {};
            for (auto d = 0; d < ndims; ++d)
                rhs_tensor_dims[d] = (!(mask & (1 << d))) ? 1 : dims[d];

            auto rhs_tensor_desc = dnn_mem_t::init_md(ndims, rhs_tensor_dims,
                    po_rhs_tensor_entry.dt, po_rhs_tensor_entry.tag);
            // Key the md by post-op index and the tensor's exec-arg.
            mds.emplace((DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx)
                                | po_rhs_tensor_entry.arg_attr_mask),
                    std::move(rhs_tensor_desc));
        }
    }

    return OK;
}
888
889void attr_args_t::prepare_dw_post_op(
890 const attr_t &attr, dnnl_data_type_t wei_dt, dnnl_data_type_t bia_dt) {
891 const int dw_idx = attr.post_ops.convolution_index();
892 if (dw_idx == -1) return;
893
894 dw_entry.wei_dt = wei_dt;
895 dw_entry.bia_dt = bia_dt;
896}
897
// Converts benchdnn-level attributes (`attr`) into a freshly created library
// attribute object (dnnl_primitive_attr_t). `attr_args` supplies auxiliary
// data that is not part of `attr` itself: explicit per-arg scale masks,
// depthwise post-op weights/bias data types and binary post-op source-1
// memory descriptors. Ownership of the returned handle passes to the caller.
dnnl_primitive_attr_t create_dnnl_attr(
        const attr_t &attr, const attr_args_t &attr_args) {
    dnnl_primitive_attr_t dnnl_attr = nullptr;
    DNN_SAFE_V(dnnl_primitive_attr_create(&dnnl_attr));

    // Scales: set a runtime scale mask per argument.
    if (!attr.scales.is_def()) {
        const auto &as = attr.scales;
        for (const auto &arg : as.scales) {
            const int arg_name = arg.first;
            if (as.is_def(arg_name)) continue;

            // Per-OC weights scales may carry an explicit mask supplied via
            // attr_args; otherwise a default mask is derived from the policy.
            if (arg_name == DNNL_ARG_WEIGHTS
                    && arg.second.policy == policy_t::PER_OC
                    && !attr_args.get(arg_name).is_def()) {
                const auto &e = attr_args.get(arg_name);
                // Only RT scales are supported.
                SAFE_V(e.runtime ? OK : FAIL);
                // Only common policy is supported in the library at this point
                int mask = e.mask;

                DNN_SAFE_V(dnnl_primitive_attr_set_scales_mask(
                        dnnl_attr, arg_name, mask));
            } else {
                const auto &e = arg.second;
                // Only RT scales are supported.
                SAFE_V(e.runtime ? OK : FAIL);
                // Only common policy is supported in the library at this point
                int mask = attr_t::get_default_mask(e.policy, arg_name);

                DNN_SAFE_V(dnnl_primitive_attr_set_scales_mask(
                        dnnl_attr, arg_name, mask));
            }
        }
    }

    // Zero points: set a runtime zero-point mask per argument.
    if (!attr.zero_points.is_def()) {
        const auto &zp = attr.zero_points;
        for (const auto &arg : zp.points) {
            const auto arg_name = arg.first;
            if (zp.is_def(arg_name)) continue;

            const auto &e = arg.second;
            // Only RT scales are supported.
            SAFE_V(e.runtime ? OK : FAIL);
            // Only common policy/single RT value are supported in the library
            // at this point
            int mask = attr_t::get_default_mask(e.policy);

            DNN_SAFE_V(dnnl_primitive_attr_set_zero_points_mask(
                    dnnl_attr, arg_name, mask));
        }
    }

    // Post-ops chain: sum, depthwise convolution, eltwise, binary, prelu.
    if (!attr.post_ops.is_def()) {
        dnnl_post_ops_t ops;
        DNN_SAFE_V(dnnl_post_ops_create(&ops));

        const auto &po = attr.post_ops;
        for (int idx = 0; idx < po.len(); ++idx) {
            const auto &e = po.entry[idx];
            if (e.is_sum_kind()) {
                DNN_SAFE_V(dnnl_post_ops_append_sum(
                        ops, e.sum.scale, e.sum.zero_point, e.sum.dt));
            } else if (e.is_convolution_kind()) {
                // dw weights/bias data types come from attr_args (filled in
                // by attr_args_t::prepare_dw_post_op()).
                const auto wei_dt = attr_args.get_dw_arg(DNNL_ARG_WEIGHTS);
                const auto bia_dt = attr_args.get_dw_arg(DNNL_ARG_BIAS);

                DNN_SAFE_V(dnnl_post_ops_append_dw(ops, wei_dt, bia_dt,
                        e.convolution.dst_dt, e.convolution.kernel,
                        e.convolution.stride, e.convolution.padding));

                const auto &wei_policy = e.convolution.wei_scale.policy;
                int wei_mask = attr_t::get_default_mask(
                        wei_policy, DNNL_ARG_WEIGHTS);
                // dw conv always has group dim
                if (wei_mask) wei_mask = 1 << wei_mask;
                if (e.convolution.wei_scale.runtime)
                    DNN_SAFE_V(dnnl_primitive_attr_set_scales_mask(dnnl_attr,
                            DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS,
                            wei_mask));

                const auto &dst_policy = e.convolution.dst_scale.policy;
                int dst_mask
                        = attr_t::get_default_mask(dst_policy, DNNL_ARG_DST);
                if (e.convolution.dst_scale.runtime)
                    DNN_SAFE_V(dnnl_primitive_attr_set_scales_mask(dnnl_attr,
                            DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_DST, dst_mask));

            } else if (e.is_eltwise_kind()) {
                DNN_SAFE_V(dnnl_post_ops_append_eltwise(
                        ops, e.eltwise.alg, e.eltwise.alpha, e.eltwise.beta));
            } else if (e.is_binary_kind()) {
                // The src1 memory descriptor must have been registered in
                // attr_args for this post-op position (non-zero ndims).
                const auto &src1_md = attr_args.get_md(
                        (DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1));
                assert(query_md_ndims(src1_md) != 0);
                DNN_SAFE_V(dnnl_post_ops_append_binary(
                        ops, e.binary.alg, src1_md));
            } else if (e.is_prelu_kind()) {
                const auto &policy = e.prelu.policy;
                const auto mask = attr_t::get_default_mask(policy);
                DNN_SAFE_V(dnnl_post_ops_append_prelu(ops, mask));
            } else {
                assert(!"unknown attr::post_ops::kind");
            }
        }
        DNN_SAFE_V(dnnl_primitive_attr_set_post_ops(dnnl_attr, ops));
        // Sanity check: the library must report the same number of post-ops
        // as were appended.
        auto c_ops = query_post_ops(dnnl_attr);
        SAFE_V(dnnl_post_ops_len(c_ops) == po.len() ? OK : FAIL);

        DNN_SAFE_V(dnnl_post_ops_destroy(ops));
    }

    DNN_SAFE_V(dnnl_primitive_attr_set_scratchpad_mode(
            dnnl_attr, attr.scratchpad_mode));

    DNN_SAFE_V(
            dnnl_primitive_attr_set_fpmath_mode(dnnl_attr, attr.fpmath_mode));

    return dnnl_attr;
}
1018
// Exception free version of std::stoi, sets idx to 0 and returns 0 in case of
// error.
static int stoi_safe(const std::string &s, size_t *idx) {
    // Cast to unsigned char: passing a plain (possibly negative) char to
    // std::isdigit is undefined behavior.
    if (s.empty() || !std::isdigit((unsigned char)s[0])) {
        *idx = 0;
        return 0;
    }
    try {
        return std::stoi(s, idx);
    } catch (const std::out_of_range &) {
        // The digit run does not fit into `int`; honor the "exception free"
        // contract and treat it as a parse error.
        *idx = 0;
        return 0;
    }
}
1028
1029static bool is_abc_tag(const std::string &tag) {
1030 if (tag == tag::undef || tag == tag::any) return true;
1031
1032 bool mask[DNNL_MAX_NDIMS] = {};
1033 for (auto &c : tag) {
1034 if (!std::isalpha(c)) continue;
1035 int idx = std::tolower(c) - 'a';
1036 if (idx < 0 || idx >= DNNL_MAX_NDIMS) return false;
1037 mask[idx] = true;
1038 }
1039 // Check there are no gaps, e.g. [1 1 1 1 0 0 ...].
1040 for (int i = 0; i < DNNL_MAX_NDIMS; i++) {
1041 if (mask[i]) continue;
1042 for (int j = i + 1; j < DNNL_MAX_NDIMS; j++)
1043 if (mask[j]) return false;
1044 break;
1045 }
1046 return true;
1047}
1048
// Validates an abc-tag (e.g. "ABcd16a16b") against the blocked-format
// grammar:
// - letters must form a contiguous dimension set (checked by is_abc_tag());
// - an uppercase letter denotes a blocked dimension and must later be
//   followed by its lowercase counterpart carrying a block size;
// - once an inner block starts, every subsequent letter needs a block size.
// With `check_enum_tags_only` set, only tags present in the library's
// dnnl_format_tag_t enum are accepted. Returns OK/FAIL.
int check_abc_tag(const std::string &tag_, bool check_enum_tags_only) {
    if (tag_.empty()) return FAIL;
    if (!is_abc_tag(tag_)) return FAIL;
    if (check_enum_tags_only) {
        if (str2fmt_tag(tag_.c_str()) == dnnl_format_tag_last) return FAIL;
        return OK;
    }

    // Per-dimension parse state: how the dimension's letter appeared so far.
    enum class dim_state_t { undef = 0, upper, lower, lower_with_block };
    dim_state_t dim_states[DNNL_MAX_NDIMS] = {};
    bool in_inner_block = false;
    auto tag = tag_;
    // Consume the tag one (optional block size + letter) pair at a time.
    while (!tag.empty()) {
        // Parse block size if presented.
        size_t idx;
        int block = stoi_safe(tag, &idx);
        if (block == 0 && idx != 0) return FAIL;
        if (idx == 0) block = 0;
        if (block > 0) in_inner_block = true;

        // Move to the first position after the block.
        tag = tag.substr(idx);
        if (tag.empty()) return FAIL;

        char c = tag[0];
        bool is_lower = ('a' <= c && c <= 'a' + DNNL_MAX_NDIMS - 1);
        bool is_upper = ('A' <= c && c <= 'A' + DNNL_MAX_NDIMS - 1);
        if (!is_lower && !is_upper) return FAIL;

        // Uppercase cannot be with block.
        if (is_upper && block != 0) return FAIL;
        // Block sizes are required within inner block.
        if (block == 0 && in_inner_block) return FAIL;

        // Check rules related to lowercase/uppercase/block order.
        int dim_idx = std::tolower(c) - 'a';
        dim_state_t prev_state = dim_states[dim_idx];
        dim_state_t cur_state = is_upper
                ? dim_state_t::upper
                : block != 0 ? dim_state_t::lower_with_block
                : dim_state_t::lower;

        switch (cur_state) {
            case dim_state_t::upper:
            case dim_state_t::lower:
                // Letter without block must be the first.
                if (prev_state != dim_state_t::undef) return FAIL;
                break;
            case dim_state_t::lower_with_block:
                // Letter with block must be after uppercase or after a letter
                // with block.
                if (prev_state != dim_state_t::upper
                        && prev_state != dim_state_t::lower_with_block)
                    return FAIL;
                break;
            default: assert(!"not expected");
        }

        // Update state, move to the next position.
        dim_states[dim_idx] = cur_state;
        tag = tag.substr(1);
    }

    for (int i = 0; i < DNNL_MAX_NDIMS; i++) {
        // Uppercase letter must be followed by lowercase.
        if (dim_states[i] == dim_state_t::upper) return FAIL;

        // Ensure there are no gaps (e.g. acd).
        if (dim_states[i] == dim_state_t::undef) {
            for (int j = i + 1; j < DNNL_MAX_NDIMS; j++)
                if (dim_states[j] != dim_state_t::undef) return FAIL;
            break;
        }
    }

    return OK;
}
1126
// Removes every occurrence of letter `c` from `tag_` together with the run of
// digits (block size) immediately preceding it, e.g. ("aBc16b", 'b') -> "aBc".
static std::string trim_letter(const std::string &tag_, char c) {
    auto tag = tag_;
    for (size_t pos = tag.find(c); pos != std::string::npos;
            pos = tag.find(c)) {
        tag.replace(pos, 1, "");
        if (pos == 0) return tag;

        pos--;
        // Strip the block size digits that preceded the removed letter.
        // Cast to unsigned char: passing a plain (possibly negative) char to
        // std::isdigit is undefined behavior.
        while (std::isdigit((unsigned char)tag[pos])) {
            tag.replace(pos, 1, "");
            if (pos == 0) break;
            pos--;
        }
    }
    return tag;
}
1143
1144// Tries to map a tag to an abc-tag according to a logical tag. For example:
1145// nchw -> abcd.
1146static std::string try_map_tag(
1147 const std::string &logical_tag, const std::string &tag, int *nmatched) {
1148 // Check if all the required letters are presented.
1149 for (auto &c : logical_tag) {
1150 if (std::toupper(c) == c
1151 && tag.find(std::tolower(c)) == std::string::npos)
1152 return {};
1153 }
1154
1155 // Check that all letters are known and assign indices to letters.
1156 int logical_indices[DNNL_MAX_NDIMS] = {};
1157 for (auto &c : tag) {
1158 if (!std::isalpha(c)) continue;
1159
1160 auto lower_pos = logical_tag.find(std::tolower(c));
1161 auto upper_pos = logical_tag.find(std::toupper(c));
1162 auto pos = (lower_pos == std::string::npos ? upper_pos : lower_pos);
1163 if (pos == std::string::npos) return {};
1164
1165 logical_indices[pos] = 1;
1166 }
1167
1168 for (int i = 0, idx = 0; i < (int)logical_tag.size(); i++) {
1169 if (logical_indices[i] == 0) continue;
1170 logical_indices[i] = idx++;
1171 }
1172
1173 (*nmatched)++;
1174 std::string mapped_tag = tag;
1175 for (int i = 0; i < (int)tag.size(); i++) {
1176 char c = tag[i];
1177 if (!std::isalpha(tag[i])) continue;
1178 auto pos = logical_tag.find(std::tolower(c));
1179 if (pos == std::string::npos) pos = logical_tag.find(std::toupper(c));
1180
1181 mapped_tag[i]
1182 = (char)(tag[i] - std::tolower(c) + 'a' + logical_indices[pos]);
1183 }
1184 return mapped_tag;
1185}
1186
1187// Maps a tag to an abc-tag.
1188static std::string map_tag_letters(const std::string &tag) {
1189 int nmatched = 0;
1190
1191 // Mapping rules:
1192 // - Uppercase letters are mandatory
1193 // - Lowercase letters are optional
1194 auto tag_goidhw = try_map_tag("GOIdhw", tag, &nmatched);
1195 auto tag_oidhw = try_map_tag("OIdhw", tag, &nmatched);
1196 auto tag_ncdhw = try_map_tag("NCdhw", tag, &nmatched);
1197 auto tag_tnc = try_map_tag("TNc", tag, &nmatched);
1198 auto tag_ldnc = try_map_tag("LDNC", tag, &nmatched);
1199 auto tag_ldigo = try_map_tag("LDigO", tag, &nmatched);
1200
1201 if (nmatched == 0) return tag;
1202 if (nmatched > 1) assert(!"Not expected: ambiguous tag.");
1203
1204 if (!tag_goidhw.empty()) return tag_goidhw;
1205 if (!tag_oidhw.empty()) return tag_oidhw;
1206 if (!tag_ncdhw.empty()) return tag_ncdhw;
1207 if (!tag_tnc.empty()) return tag_tnc;
1208 if (!tag_ldnc.empty()) return tag_ldnc;
1209 if (!tag_ldigo.empty()) return tag_ldigo;
1210
1211 return tag;
1212}
1213
1214std::string trim_tag(const std::string &tag, int ndims) {
1215 int mask = 0;
1216 for (int d = 0; d < ndims; d++) {
1217 mask += (1 << d);
1218 }
1219 return trim_tag_by_mask(tag, mask);
1220}
1221
// Removes from `tag` every dimension whose bit is not set in `mask`, then
// renames the surviving letters so they become consecutive starting from 'a'
// (or 'A'). E.g. tag = "abcd", mask = 0b1010 keeps "b"/"d", which are then
// renamed to "a"/"b".
std::string trim_tag_by_mask(const std::string &tag, int mask) {
    std::string trimmed_tag = tag;
    int ndims_saved = 0;
    // Drop both the lower- and uppercase forms (and their block sizes) of
    // every masked-out dimension letter.
    for (char c = 'a', d = 0; c < 'a' + (char)(DNNL_MAX_NDIMS); c++, d++) {
        if (!(mask & (1 << d))) {
            trimmed_tag = trim_letter(trimmed_tag, c);
            trimmed_tag = trim_letter(trimmed_tag, std::toupper(c));
        } else {
            ndims_saved++;
        }
    }

    // Mask may operate over non-consecutive dimensions. The piece below will
    // make trimmed_tag consist of consecutive dimensions starting from "a" or
    // "A". E.g., mask = 2 + 8 = 10, trimmed_tag will contain "b" and "d"
    // letters, and will be converted into one with "a" and "b".
    int mask_copy = mask;
    for (int i = 0; i < ndims_saved; i++) {
        // Count the masked-out dimensions preceding the i-th surviving one
        // (mask_copy is consumed bit by bit across iterations).
        int dist_to_a = 0;
        while (mask_copy % 2 == 0) {
            mask_copy /= 2;
            dist_to_a++;
        }
        mask_copy /= 2;
        if (dist_to_a == 0) continue;

        // Shift every letter that comes after position i in the alphabet down
        // by the gap just measured.
        for (size_t j = 0; j < trimmed_tag.size(); j++) {
            char str_j = trimmed_tag[j];
            if (std::isalpha(str_j) && std::tolower(str_j) > 'a' + i) {
                std::string rep_str(1, str_j - dist_to_a);
                trimmed_tag.replace(j, 1, rep_str);
            }
        }
    }

    return trimmed_tag;
}
1259
// Normalizes a user-provided tag to canonical abc form:
// - "undef"/"any" (or ndims == 0) are returned unchanged;
// - "x" maps to "a" (1D);
// - meta-tags containing an "x" (e.g. "abx", "axb") get the "x" expanded to
//   enough consecutive letters to reach `ndims` dimensions;
// - otherwise logical tags (e.g. "nchw", "OIhw") are mapped letter-wise.
std::string normalize_tag(const std::string &tag_, int ndims) {
    std::string tag = tag_;
    if (tag == tag::undef || tag == tag::any || ndims == 0) return tag;
    if (tag == tag::x) {
        if (ndims >= 0) assert(ndims == 1);
        return "a";
    }

    // Handle meta-tags (abx, axb, etc).
    auto pos = tag.find("x");
    if (pos != std::string::npos) {
        // Non-grouped tags will start `x` from `c`, but grouped will most of
        // times start `x` from `d`.
        char start_x = 'c';
        // Find the last explicit letter in the tag; expansion of `x` starts
        // right after it.
        for (char c = 'a' + DNNL_MAX_NDIMS - 1; c >= 'b'; c--) {
            if (tag.find(c) != std::string::npos) {
                start_x = c + 1;
                break;
            }
        }
        // Adjust ndims if they are not specified.
        int meta_ndims = (ndims == -1 ? (start_x - 'a' + 1) : ndims);
        std::string tail;
        for (int i = 0; i < meta_ndims - (start_x - 'a'); i++)
            tail += (start_x + i);
        // Replace "x" with the generated letters; trim_tag() handles the case
        // where fewer dimensions are requested than the tag spells out.
        return trim_tag(tag.replace(pos, 1, tail), meta_ndims);
    }

    return map_tag_letters(tag);
}
1290
1291int check_tag(const std::string &tag_, bool check_enum_tags_only) {
1292 auto tag = normalize_tag(tag_);
1293 if (tag == tag::undef || tag == tag::any) return OK;
1294 return check_abc_tag(tag, check_enum_tags_only);
1295}
1296
1297void maybe_scale(const attr_t &attr, float &d, const float *scales, int64_t c,
1298 int arg, bool opposite_scale) {
1299 if (attr.scales.is_def()) return;
1300
1301 const auto &e = attr.scales.get(arg);
1302 if (!e.is_def()) {
1303 int64_t idx = e.policy == policy_t::COMMON ? 0 : c;
1304 if (opposite_scale)
1305 d /= scales[idx];
1306 else
1307 d *= scales[idx];
1308 }
1309}
1310
1311void maybe_zero_point(const attr_t &attr, float &d, const int32_t *zero_points,
1312 int64_t c, int arg, bool opposite_zero_point) {
1313 if (attr.zero_points.is_def()) return;
1314
1315 const auto &e = attr.zero_points.get(arg);
1316 if (!e.is_def()) {
1317 const int idx = e.policy == policy_t::COMMON ? 0 : c;
1318 const int zp_sign = opposite_zero_point ? -1 : 1;
1319 d -= zp_sign * zero_points[idx];
1320 }
1321}
1322
// Reference forward eltwise computation: applies the algorithm `kind` with
// parameters `alpha`/`beta` to `src`. NaN inputs are propagated unchanged.
// Unknown kinds assert and return NAN.
float compute_eltwise_fwd(pk_t kind, float src, float alpha, float beta) {
    // don't compute on nan, propagate it
    if (std::isnan(src)) return NAN;

    using namespace dnnl::impl::math;

    switch (kind) {
        case pk_t::RELU: return relu_fwd(src, alpha);
        case pk_t::TANH: return tanh_fwd(src);
        case pk_t::ELU: return elu_fwd(src, alpha);
        case pk_t::SQUARE: return square_fwd(src);
        case pk_t::ABS: return abs_fwd(src);
        case pk_t::SQRT: return sqrt_fwd(src);
        case pk_t::LINEAR: return linear_fwd(src, alpha, beta);
        case pk_t::SRELU: return soft_relu_fwd(src, alpha);
        case pk_t::MISH: return mish_fwd(src);
        case pk_t::LOGISTIC: return logistic_fwd(src);
        case pk_t::EXP: return exp_fwd(src);
        case pk_t::GELU_TANH: return gelu_tanh_fwd(src);
        case pk_t::SWISH: return swish_fwd(src, alpha);
        case pk_t::LOG: return log_fwd(src);
        case pk_t::CLIP: return clip_fwd(src, alpha, beta);
        case pk_t::CLIP_V2: return clip_v2_fwd(src, alpha, beta);
        case pk_t::POW: return pow_fwd(src, alpha, beta);
        case pk_t::GELU_ERF: return gelu_erf_fwd(src);
        case pk_t::ROUND: return round_fwd(src);
        case pk_t::HARDSWISH: return hardswish_fwd(src, alpha, beta);
        case pk_t::HARDSIGMOID: return hardsigmoid_fwd(src, alpha, beta);
        // *_DST flavors share the regular forward formulas; they only differ
        // on the backward pass (use dst instead of src).
        case pk_t::RELU_DST: return relu_fwd(src, alpha);
        case pk_t::TANH_DST: return tanh_fwd(src);
        case pk_t::ELU_DST: return elu_fwd(src, alpha);
        case pk_t::SQRT_DST: return sqrt_fwd(src);
        case pk_t::LOGISTIC_DST: return logistic_fwd(src);
        case pk_t::EXP_DST: return exp_fwd(src);
        case pk_t::CLIP_V2_DST: return clip_v2_fwd(src, alpha, beta);

        default: assert(!"unknown attr::post_ops::kind");
    };
    return NAN;
}
1363
// Reference backward eltwise computation: returns the input gradient for
// algorithm `kind` given the output gradient `d_dst` and the forward input
// `src` (for *_DST flavors, `src` carries the forward *output* instead).
// Unknown kinds assert and return NAN.
// NOTE(review): unlike compute_eltwise_fwd(), there is no explicit NaN guard
// here — presumably the bwd formulas propagate NaN naturally; confirm.
float compute_eltwise_bwd(
        pk_t kind, float d_dst, float src, float alpha, float beta) {
    using namespace dnnl::impl::math;

    switch (kind) {
        case pk_t::RELU: return relu_bwd(d_dst, src, alpha);
        case pk_t::TANH: return tanh_bwd(d_dst, src);
        case pk_t::ELU: return elu_bwd(d_dst, src, alpha);
        case pk_t::SQUARE: return square_bwd(d_dst, src);
        case pk_t::ABS: return abs_bwd(d_dst, src);
        case pk_t::SQRT: return sqrt_bwd(d_dst, src);
        case pk_t::LINEAR: return linear_bwd(d_dst, src, alpha, beta);
        case pk_t::SRELU: return soft_relu_bwd(d_dst, src, alpha);
        case pk_t::MISH: return mish_bwd(d_dst, src);
        case pk_t::LOGISTIC: return logistic_bwd(d_dst, src);
        case pk_t::EXP: return exp_bwd(d_dst, src);
        case pk_t::GELU_TANH: return gelu_tanh_bwd(d_dst, src);
        case pk_t::SWISH: return swish_bwd(d_dst, src, alpha);
        case pk_t::LOG: return log_bwd(d_dst, src);
        case pk_t::CLIP: return clip_bwd(d_dst, src, alpha, beta);
        case pk_t::CLIP_V2: return clip_v2_bwd(d_dst, src, alpha, beta);
        case pk_t::POW: return pow_bwd(d_dst, src, alpha, beta);
        case pk_t::GELU_ERF: return gelu_erf_bwd(d_dst, src);
        case pk_t::HARDSWISH: return hardswish_bwd(d_dst, src, alpha, beta);
        case pk_t::HARDSIGMOID: return hardsigmoid_bwd(d_dst, src, alpha, beta);

        // *_DST flavors compute the gradient from the forward output (passed
        // in via `src`).
        case pk_t::RELU_DST: return relu_bwd_use_dst(d_dst, src, alpha);
        case pk_t::TANH_DST: return tanh_bwd_use_dst(d_dst, src);
        case pk_t::ELU_DST: return elu_bwd_use_dst(d_dst, src, alpha);
        case pk_t::SQRT_DST: return sqrt_bwd_use_dst(d_dst, src);
        case pk_t::LOGISTIC_DST: return logistic_bwd_use_dst(d_dst, src);
        case pk_t::EXP_DST: return exp_bwd_use_dst(d_dst, src);
        case pk_t::CLIP_V2_DST:
            return clip_v2_bwd_use_dst(d_dst, src, alpha, beta);

        default: assert(!"unknown attr::post_ops::kind");
    }
    return NAN;
}
1403
1404float compute_binary(pk_t kind, float src0, float src1) {
1405 // don't compute on nan, propagate it
1406 if (std::isnan(src0) || std::isnan(src1)) return NAN;
1407
1408 if (kind == pk_t::ADD) {
1409 return src0 + src1;
1410 } else if (kind == pk_t::MUL) {
1411 return src0 * src1;
1412 } else if (kind == pk_t::MAX) {
1413 return MAX2(src0, src1);
1414 } else if (kind == pk_t::MIN) {
1415 return MIN2(src0, src1);
1416 } else if (kind == pk_t::DIV) {
1417 return src0 / src1;
1418 } else if (kind == pk_t::SUB) {
1419 return src0 - src1;
1420 } else if (kind == pk_t::GE) {
1421 return src0 >= src1;
1422 } else if (kind == pk_t::GT) {
1423 return src0 > src1;
1424 } else if (kind == pk_t::LE) {
1425 return src0 <= src1;
1426 } else if (kind == pk_t::LT) {
1427 return src0 < src1;
1428 } else if (kind == pk_t::EQ) {
1429 return src0 == src1;
1430 } else if (kind == pk_t::NE) {
1431 return src0 != src1;
1432 } else {
1433 assert(!"operation not supported!");
1434 }
1435 return NAN;
1436}
1437
1438void maybe_post_ops(const attr_t &attr, float &val, float sum_val,
1439 const std::vector<float> &v_po_vals) {
1440 using namespace dnnl::impl::math;
1441
1442 auto it_po = v_po_vals.begin();
1443 const auto &po = attr.post_ops;
1444 for (int idx = 0; idx < po.len(); ++idx) {
1445 const auto &e = po.entry[idx];
1446
1447 if (e.is_sum_kind()) {
1448 val += e.sum.scale * (sum_val - e.sum.zero_point);
1449 } else if (e.is_convolution_kind()) {
1450 continue;
1451 } else if (e.is_eltwise_kind()) {
1452 const auto &a = e.eltwise.alpha;
1453 const auto &b = e.eltwise.beta;
1454 val = compute_eltwise_fwd(e.kind, val, a, b);
1455 } else if (e.is_binary_kind()) {
1456 val = compute_binary(e.kind, val, *it_po);
1457 it_po++;
1458 } else if (e.is_prelu_kind()) {
1459 val = val > 0 ? val : val * (*it_po);
1460 it_po++;
1461 }
1462 }
1463}
1464
1465void update_cpu_ref_attrs(attr_t &attr, dnnl_data_type_t new_dt) {
1466 auto &po = attr.post_ops;
1467 for (int idx = 0; idx < po.len(); ++idx) {
1468 auto &e = po.entry[idx];
1469 if (!e.is_binary_kind()) continue;
1470
1471 e.binary.src1_dt = new_dt;
1472 e.binary.tag = tag::abx; // Hardcoded in setup_binary_po as well.
1473 }
1474}
1475
1476#undef BENCHDNN_DNNL_ARG_UNDEF
1477