1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef DNN_TYPES_HPP |
18 | #define DNN_TYPES_HPP |
19 | |
20 | #include <stddef.h> |
21 | #include <stdlib.h> |
22 | #include <string.h> |
23 | |
24 | #include <iostream> |
25 | #include <map> |
26 | #include <memory> |
27 | #include <string> |
28 | #include <vector> |
29 | #include <unordered_map> |
30 | |
31 | #include "common.hpp" |
32 | #include "oneapi/dnnl/dnnl_types.h" |
33 | #include "utils/wrapper.hpp" |
34 | |
// Named tag strings. They are only declared `extern` here and are defined in
// another translation unit. Presumably these are memory-format tags or
// meta-tags of the kind validated by check_tag() below; confirm against the
// definitions.
namespace tag {
extern const char *x;
extern const char *abx;
extern const char *axb;
extern const char *any;
extern const char *undef;
} // namespace tag
42 | |
// Direction of a benchdnn problem, encoded as a set of bit flags.
// FLAG_DAT / FLAG_WEI / FLAG_BIA occupy bits 0-2 and FLAG_FWD / FLAG_BWD /
// FLAG_INF occupy bits 5-7. The combined enumerators (FWD_D, BWD_W, ...) are
// unions of those bits.
enum dir_t {
    DIR_UNDEF = 0,
    FLAG_DAT = 1 << 0,
    FLAG_WEI = 1 << 1,
    FLAG_BIA = 1 << 2,
    FLAG_FWD = 1 << 5,
    FLAG_BWD = 1 << 6,
    FLAG_INF = 1 << 7,
    FWD_D = FLAG_FWD | FLAG_DAT,
    FWD_I = FLAG_FWD | FLAG_DAT | FLAG_INF,
    FWD_B = FLAG_FWD | FLAG_DAT | FLAG_BIA,
    BWD_D = FLAG_BWD | FLAG_DAT,
    BWD_DW = FLAG_BWD | FLAG_DAT | FLAG_WEI,
    BWD_W = FLAG_BWD | FLAG_WEI,
    BWD_WB = FLAG_BWD | FLAG_WEI | FLAG_BIA,
};
// Converts a string into a dir_t value (defined elsewhere).
dir_t str2dir(const char *str);
60 | |
/* TODO: merge prop and dir_t (in favor of prop) */
// String conversion for a library propagation kind (defined elsewhere).
const char *prop2str(dnnl_prop_kind_t prop);
// Maps a benchdnn direction onto a library propagation kind (defined
// elsewhere).
dnnl_prop_kind_t prop2prop_kind(dir_t dir);

// Stream printers for a direction, a library data type and a library engine
// kind. Only declared here; the definitions live in another file.
std::ostream &operator<<(std::ostream &s, dir_t dir);
std::ostream &operator<<(std::ostream &s, dnnl_data_type_t dt);
std::ostream &operator<<(std::ostream &s, dnnl_engine_kind_t ek);
// Prints the elements of `v` separated by ':' (e.g. {1, 2, 3} -> "1:2:3").
// An empty vector prints nothing. Previously `v[0]` was read unconditionally,
// which is undefined behavior for an empty vector.
template <typename T>
std::ostream &operator<<(std::ostream &s, const std::vector<T> &v) {
    if (v.empty()) return s;
    // The first element is printed on its own so that any pending stream
    // width applies to it, exactly as before.
    s << v[0];
    for (size_t d = 1; d < v.size(); ++d)
        s << ":" << v[d];
    return s;
}
75 | |
// Kinds of tensors a primitive may operate on. The enumerators are numbered
// implicitly in declaration order starting at SRC = 0. DAT_TOTAL therefore
// equals the number of kinds and must stay the last enumerator.
enum data_kind_t {
    SRC = 0,
    WEI,
    BIA,
    DST,
    ACC,
    // bnorm, lnorm
    SRC_1,
    MEAN,
    VAR,
    SC,
    SH,
    // rnn
    DST_ITER,
    DST_ITER_C,
    AUGRU_ATTENTION,
    SRC_ITER,
    SRC_ITER_C,
    WEI_ITER,
    WEI_PEEPHOLE,
    WEI_PROJECTION,

    DAT_TOTAL,
};
// String conversion for a data_kind_t value (defined elsewhere).
const char *data_kind2str(data_kind_t kind);
101 | |
102 | struct attr_t { |
103 | // policy_t defines the way entity values will be applied to a tensor |
104 | enum policy_t { |
105 | COMMON = 0, // single value for each point in a tensor |
106 | // apply a single value per... |
107 | PER_OC, // channel (dims[1]) point |
108 | PER_DIM_0, // ... dims[0] point. |
109 | PER_DIM_1, // ... dims[1] point. |
110 | PER_DIM_01, // ... unique combination of dims[0] and dims[1] points. |
111 | PER_DIM_2, // ... dims[2] point. |
112 | PER_DIM_023, // ... combination of dims[0], dims[2], dims[3] points. |
113 | PER_DIM_23, // ... combination of dims[2] and dims[3] points. |
114 | PER_DIM_03, // ... combination of dims[0] and dims[3] points. |
115 | PER_DIM_3, // ... dims[3] point. |
116 | PER_TENSOR, // ... point in the tensor. |
117 | POLICY_TOTAL // guard |
118 | }; |
119 | |
120 | static policy_t str2policy(const std::string &str); |
121 | static const char *policy2str(policy_t policy); |
122 | static int get_default_mask(policy_t policy, int arg = DNNL_ARG_DST); |
123 | |
124 | struct scale_t { |
125 | scale_t(policy_t apolicy = COMMON, float ascale = 1., |
126 | bool aruntime = false) |
127 | : policy(apolicy), scale(ascale), runtime(aruntime) {} |
128 | |
129 | int from_str(const std::string &s); |
130 | |
131 | bool is_def() const { |
132 | return policy == COMMON && scale == 1. && runtime == false; |
133 | } |
134 | |
135 | policy_t policy = COMMON; |
136 | float scale = 1.; |
137 | bool runtime = false; |
138 | }; |
139 | |
140 | struct zero_points_t { |
141 | struct entry_t { |
142 | entry_t(policy_t apolicy = COMMON, int avalue = 0, |
143 | bool aruntime = false) |
144 | : policy(apolicy), value(avalue), runtime(aruntime) {} |
145 | |
146 | entry_t(const entry_t &other) |
147 | : policy(other.policy) |
148 | , value(other.value) |
149 | , runtime(other.runtime) {} |
150 | |
151 | bool is_def() const { |
152 | return policy == COMMON && value == 0 && runtime == false; |
153 | } |
154 | |
155 | policy_t policy = COMMON; |
156 | int value = 0; |
157 | bool runtime = false; |
158 | }; |
159 | |
160 | int from_str(const std::string &s); |
161 | |
162 | int operator[](int arg) const { return get(arg).value; } |
163 | bool runtime(int arg) const { return get(arg).runtime; } |
164 | |
165 | bool is_def(int arg) const { return get(arg).is_def(); } |
166 | bool is_def() const { return points.empty(); } |
167 | |
168 | void set(int arg, policy_t policy, int value, bool runtime) { |
169 | set(arg, entry_t(policy, value, runtime)); |
170 | } |
171 | void set(int arg, const entry_t &entry) { |
172 | if (!entry.is_def()) points[arg] = entry; |
173 | } |
174 | entry_t get(int arg) const { |
175 | const auto it = points.find(arg); |
176 | return it == points.end() ? entry_t() : it->second; |
177 | } |
178 | |
179 | std::unordered_map<int, entry_t>::const_iterator begin() const { |
180 | return points.begin(); |
181 | } |
182 | std::unordered_map<int, entry_t>::const_iterator end() const { |
183 | return points.end(); |
184 | } |
185 | |
186 | zero_points_t() : points() {} // needed for debug icc190 build; |
187 | std::unordered_map<int, entry_t> points; |
188 | }; |
189 | |
190 | struct arg_scales_t { |
191 | void set(int arg, scale_t scale) { scales[arg] = scale; } |
192 | |
193 | scale_t get(int arg) const { |
194 | const auto &s = scales.find(arg); |
195 | return s == scales.end() ? scale_t() : s->second; |
196 | } |
197 | |
198 | bool is_def(int arg) const { return get(arg).is_def(); } |
199 | bool is_def() const { |
200 | bool def = true; |
201 | for (const auto &e : scales) { |
202 | def = def && is_def(e.first); |
203 | } |
204 | return def; |
205 | } |
206 | int from_str(const std::string &s); |
207 | |
208 | arg_scales_t() : scales() {} // needed for debug icc190 build; |
209 | |
210 | std::map<int, scale_t> scales; |
211 | }; |
212 | |
213 | struct post_ops_t { |
214 | enum kind_t { |
215 | // sum |
216 | SUM, |
217 | // depthwise convolution |
218 | DW, |
219 | DW_K3S1P1, |
220 | DW_K3S2P1, |
221 | // eltwise |
222 | ELTWISE_START, // a guard to check kind is eltwise |
223 | ABS, |
224 | CLIP, |
225 | CLIP_V2, |
226 | CLIP_V2_DST, |
227 | ELU, |
228 | ELU_DST, |
229 | EXP, |
230 | EXP_DST, |
231 | GELU_ERF, |
232 | GELU_TANH, |
233 | HARDSIGMOID, |
234 | HARDSWISH, |
235 | LINEAR, |
236 | LOG, |
237 | LOGISTIC, |
238 | LOGISTIC_DST, |
239 | MISH, |
240 | POW, |
241 | RELU, |
242 | RELU_DST, |
243 | ROUND, |
244 | SQRT, |
245 | SQRT_DST, |
246 | SQUARE, |
247 | SRELU, |
248 | SWISH, |
249 | TANH, |
250 | TANH_DST, |
251 | ELTWISE_END, // a guard to check kind is eltwise |
252 | // binary |
253 | BINARY_START, // a guard to check kind is binary |
254 | ADD, |
255 | DIV, |
256 | EQ, |
257 | GE, |
258 | GT, |
259 | LE, |
260 | LT, |
261 | MAX, |
262 | MIN, |
263 | MUL, |
264 | NE, |
265 | SUB, |
266 | BINARY_END, // a guard to check kind is binary |
267 | // prelu |
268 | PRELU, |
269 | // guard entry |
270 | KIND_TOTAL |
271 | }; |
272 | static kind_t str2kind(const std::string &str); |
273 | static const char *kind2str(kind_t kind); |
274 | static dnnl_alg_kind_t kind2dnnl_kind(kind_t kind); |
275 | |
276 | struct entry_t { |
277 | entry_t(kind_t akind) : kind(akind) { |
278 | if (is_sum_kind()) { |
279 | } else if (is_eltwise_kind()) { |
280 | eltwise.alg = kind2dnnl_kind(kind); |
281 | } else if (is_convolution_kind()) { |
282 | convolution.src_scale = scale_t(); |
283 | convolution.wei_scale = scale_t(); |
284 | convolution.dst_scale = scale_t(); |
285 | if (kind != DW) { |
286 | convolution.kernel = 3; |
287 | convolution.stride = kind == DW_K3S1P1 ? 1 : 2; |
288 | convolution.padding = 1; |
289 | } |
290 | } else if (is_binary_kind()) { |
291 | binary.alg = kind2dnnl_kind(kind); |
292 | } |
293 | } |
294 | |
295 | kind_t kind; |
296 | struct { |
297 | float scale = 1.f; |
298 | int32_t zero_point = 0; |
299 | dnnl_data_type_t dt = dnnl_data_type_undef; |
300 | } sum; |
301 | struct { |
302 | dnnl_alg_kind_t alg = dnnl_alg_kind_undef; |
303 | float alpha = 0.f; |
304 | float beta = 0.f; |
305 | } eltwise; |
306 | struct { |
307 | int kernel = 0; |
308 | int stride = 0; |
309 | int padding = 0; |
310 | dnnl_data_type_t dst_dt = dnnl_f32; |
311 | scale_t src_scale; |
312 | scale_t wei_scale; |
313 | scale_t dst_scale; |
314 | } convolution; |
315 | struct { |
316 | dnnl_alg_kind_t alg = dnnl_alg_kind_undef; |
317 | dnnl_data_type_t src1_dt = dnnl_data_type_undef; |
318 | policy_t policy = policy_t::COMMON; |
319 | std::string tag = tag::any; |
320 | } binary; |
321 | struct { |
322 | policy_t policy = policy_t::COMMON; |
323 | } prelu; |
324 | |
325 | bool is_sum_kind() const; |
326 | bool is_convolution_kind() const; |
327 | bool is_eltwise_kind() const; |
328 | bool is_binary_kind() const; |
329 | bool is_prelu_kind() const; |
330 | }; |
331 | |
332 | post_ops_t() : entry() {} |
333 | |
334 | int from_str(const std::string &s); |
335 | |
336 | int len() const { return (int)entry.size(); } |
337 | bool is_def() const { return len() == 0; } |
338 | |
339 | int find(kind_t kind, int start = 0, int stop = -1) const; |
340 | int eltwise_index() const; |
341 | int convolution_index() const; |
342 | int binary_index() const; |
343 | int prelu_index() const; |
344 | |
345 | std::vector<std::pair<int, int>> get_po_masks() const; |
346 | |
347 | std::vector<entry_t> entry; |
348 | }; |
349 | |
350 | attr_t() |
351 | : scratchpad_mode(dnnl_scratchpad_mode_library) |
352 | , fpmath_mode(dnnl_fpmath_mode_strict) {} |
353 | |
354 | template <typename First, typename... Rest> |
355 | void insert(const First &first, const Rest &... rest) { |
356 | this->insert(first); |
357 | if (sizeof...(rest) > 0) this->insert(rest...); |
358 | } |
359 | |
360 | void insert(const arg_scales_t &as) { this->scales = as; } |
361 | void insert(const zero_points_t &zp) { this->zero_points = zp; } |
362 | void insert(const post_ops_t &po) { this->post_ops = po; } |
363 | void insert(dnnl_scratchpad_mode_t sm) { this->scratchpad_mode = sm; } |
364 | void insert(dnnl_fpmath_mode_t fpm) { this->fpmath_mode = fpm; } |
365 | |
366 | arg_scales_t scales; |
367 | zero_points_t zero_points; |
368 | post_ops_t post_ops; |
369 | dnnl_scratchpad_mode_t scratchpad_mode; |
370 | dnnl_fpmath_mode_t fpmath_mode; |
371 | |
372 | bool is_def(bool skip_fpmath = false) const; |
373 | }; |
374 | |
// CPU ISA hint selection. Wraps a cpu_hints_t value and converts it to and
// from its string name.
struct isa_hints_t {
    enum cpu_hints_t {
        // If DNNL_CPU_ISA_HINTS is set then use hints from there
        // Otherwise no hints
        none = 0x0,
        // No CPU ISA specific hints
        // Will override DNNL_CPU_ISA_HINTS if that is available too
        no_hints = 0x1,
        // Use prefer_ymm CPU ISA hint
        // Will override DNNL_CPU_ISA_HINTS if that is available too
        prefer_ymm = 0x2,
    };

    cpu_hints_t hints_;
    isa_hints_t(cpu_hints_t hints) : hints_(hints) {}

    // Returns the stored hint. It is const because it only reads `hints_`,
    // so it can also be called through a `const isa_hints_t &`.
    cpu_hints_t get() const { return hints_; }

    // Returns the name of the stored hint. Asserts on an unknown value.
    static std::string hints2str(const isa_hints_t &isa_hints) {
        switch (isa_hints.hints_) {
            case none: return "none";
            case no_hints: return "no_hints";
            case prefer_ymm: return "prefer_ymm";
            default: assert(!"unknown hint"); return "unknown_hint";
        }
    }

    // Parses a hint name case-insensitively ("prefer_ymm" or "no_hints").
    // Any other string, or a null pointer (previously passed straight to
    // strcasecmp, which is undefined behavior), yields `none`.
    static isa_hints_t str2hints(const char *str) {
        cpu_hints_t hints = none;
        if (str == nullptr) return isa_hints_t(hints);

        if (strcasecmp(str, "prefer_ymm") == 0)
            hints = prefer_ymm;
        else if (strcasecmp(str, "no_hints") == 0)
            hints = no_hints;

        return isa_hints_t(hints);
    }
};
413 | |
// Namespace-scope alias for attr_t::policy_t.
using policy_t = attr_t::policy_t;

// Stream printers for the attribute types above. Only declared here; the
// definitions live in another file.
std::ostream &operator<<(std::ostream &s, const policy_t &policy);
std::ostream &operator<<(std::ostream &s, const attr_t::scale_t &scale);
std::ostream &operator<<(
        std::ostream &s, const attr_t::zero_points_t &zero_points);
std::ostream &operator<<(std::ostream &s, const attr_t::arg_scales_t &scales);
std::ostream &operator<<(std::ostream &s, const attr_t::post_ops_t::kind_t &k);
std::ostream &operator<<(std::ostream &s, const attr_t::post_ops_t &post_ops);
std::ostream &operator<<(std::ostream &s, dnnl_scratchpad_mode_t sm);
std::ostream &operator<<(std::ostream &s, dnnl_fpmath_mode_t fm);
std::ostream &operator<<(std::ostream &s, const attr_t &attr);
426 | |
427 | // A container for additional data and info, not available from user's input at |
428 | // parse time, but which are required to create the library attributes. |
429 | struct attr_args_t { |
430 | struct entry_t { |
431 | entry_t(const void *vals = NULL, int64_t count = 1, int mask = -1, |
432 | bool runtime = false) |
433 | : vals(vals), count(count), mask(mask), runtime(runtime) {} |
434 | |
435 | bool is_def() const { |
436 | return vals == NULL && count == 1 && mask == -1 && runtime == false; |
437 | } |
438 | |
439 | int64_t get_count(policy_t policy) const { |
440 | return (policy == policy_t::COMMON || runtime) ? 1 : count; |
441 | } |
442 | |
443 | int get_mask(policy_t policy) const { |
444 | return mask == -1 ? attr_t::get_default_mask(policy) : mask; |
445 | } |
446 | |
447 | const float *get_float_ptr() const { |
448 | return runtime ? &DNNL_RUNTIME_F32_VAL |
449 | : static_cast<const float *>(vals); |
450 | } |
451 | |
452 | const void *vals = NULL; |
453 | int64_t count = 1; |
454 | int mask = -1; |
455 | bool runtime = false; |
456 | }; |
457 | |
458 | struct dw_t { |
459 | dnnl_data_type_t wei_dt = dnnl_data_type_undef; |
460 | dnnl_data_type_t bia_dt = dnnl_data_type_undef; |
461 | }; |
462 | |
463 | attr_args_t() = default; |
464 | |
465 | void prepare_output_scales( |
466 | const attr_t &attr, const void *vals, int64_t count, int mask = -1); |
467 | |
468 | void prepare_scales(const attr_t &attr, int arg, const void *vals, |
469 | int64_t count, int mask = -1); |
470 | |
471 | int prepare_post_ops_mds( |
472 | const attr_t &attr, int ndims, const dnnl_dims_t dims); |
473 | |
474 | void prepare_dw_post_op(const attr_t &attr, dnnl_data_type_t wei_dt, |
475 | dnnl_data_type_t bia_dt); |
476 | |
477 | entry_t get(int arg) const { |
478 | const auto it = entries.find(arg); |
479 | return it == entries.end() ? entry_t() : it->second; |
480 | } |
481 | |
482 | dnnl_memory_desc_t get_md(int arg) const { |
483 | const auto it = mds.find(arg); |
484 | return it == mds.end() ? nullptr : (dnnl_memory_desc_t)it->second; |
485 | } |
486 | |
487 | dnnl_data_type_t get_dw_arg(int arg) const { |
488 | if (arg == DNNL_ARG_WEIGHTS) |
489 | return dw_entry.wei_dt; |
490 | else if (arg == DNNL_ARG_BIAS) |
491 | return dw_entry.bia_dt; |
492 | else { |
493 | assert(!"unsupported_argument" ); |
494 | return dnnl_data_type_undef; |
495 | } |
496 | } |
497 | |
498 | private: |
499 | void insert( |
500 | int arg, const void *vals, int64_t count, int mask, bool runtime) { |
501 | entries.insert( |
502 | std::make_pair(arg, entry_t(vals, count, mask, runtime))); |
503 | } |
504 | |
505 | std::map<int, entry_t> entries; |
506 | std::map<int, benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t>> mds; |
507 | dw_t dw_entry; // only single dw fusion is supported |
508 | }; |
509 | |
// Dumps the global parameters into `s` (defined elsewhere).
std::ostream &dump_global_params(std::ostream &s);

// Validates a tag/meta-tag.
int check_tag(const std::string &tag_, bool check_enum_tags_only = false);

// Validates a tag in abc notation.
int check_abc_tag(const std::string &tag, bool check_enum_tags_only = false);

// Removes extra dimensions from a tag according to ndims.
std::string trim_tag(const std::string &tag, int ndims);
// Removes extra dimensions from a tag according to mask. The `ndims` version
// above is a special case of this `mask` version, one in which the first
// `ndims` dimensions of the mask are set.
std::string trim_tag_by_mask(const std::string &tag, int mask);

// Converts a tag/meta-tag to abc notation.
std::string normalize_tag(const std::string &tag, int ndims = -1);

// Builds a library primitive attribute from the parsed `attr` and the extra
// data in `attr_args`. Presumably the caller owns and destroys the returned
// handle; confirm against the definition.
dnnl_primitive_attr_t create_dnnl_attr(
        const attr_t &attr, const attr_args_t &attr_args);

// Parsers converting strings into an engine kind, a scratchpad mode and an
// fpmath mode (defined elsewhere).
dnnl_engine_kind_t str2engine_kind(const char *str);
dnnl_scratchpad_mode_t str2scratchpad_mode(const char *str);
dnnl_fpmath_mode_t str2fpmath_mode(const char *str);
534 | |
// Adjusts `d` in place by the scale configured in `attr` for argument `arg`,
// reading the value for index `c` from `scales`. `opposite_scale` presumably
// applies the inverse operation; confirm against the definition.
void maybe_scale(const attr_t &attr, float &d, const float *scales, int64_t c,
        int arg, bool opposite_scale = false);
// Zero-point counterpart of maybe_scale(). `opposite_zero_point` presumably
// selects the inverse adjustment; confirm against the definition.
void maybe_zero_point(const attr_t &attr, float &d, const int32_t *zero_points,
        int64_t c, int arg, bool opposite_zero_point = false);
// Evaluates the eltwise post-op `kind` on `src` with parameters alpha/beta.
float compute_eltwise_fwd(
        attr_t::post_ops_t::kind_t kind, float src, float alpha, float beta);
// Backward counterpart of compute_eltwise_fwd(), taking the incoming `d_dst`.
float compute_eltwise_bwd(attr_t::post_ops_t::kind_t kind, float d_dst,
        float src, float alpha, float beta);
// Evaluates the binary post-op `kind` on `src0` and `src1`.
float compute_binary(attr_t::post_ops_t::kind_t kind, float src0, float src1);
544 | void maybe_post_ops(const attr_t &attr, float &val, float sum_val, |
545 | const std::vector<float> &v_po_vals); |
546 | inline void maybe_post_ops( |
547 | const attr_t &attr, float &val, float sum_val = 0.f) { |
548 | maybe_post_ops(attr, val, sum_val, std::vector<float>()); |
549 | } |
550 | |
// When the fast-ref-gpu option is used, the reference expects everything to
// be in the f32 data type and no additional memories to come from runtime
// attributes. That is why all data types are updated to f32 (the `new_dt`
// default) and all runtime arguments are removed, making them constant when
// possible.
void update_cpu_ref_attrs(attr_t &attr, dnnl_data_type_t new_dt = dnnl_f32);
556 | #endif |
557 | |