1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef DNN_TYPES_HPP
18#define DNN_TYPES_HPP
19
20#include <stddef.h>
21#include <stdlib.h>
22#include <string.h>
23
24#include <iostream>
25#include <map>
26#include <memory>
27#include <string>
28#include <vector>
29#include <unordered_map>
30
31#include "common.hpp"
32#include "oneapi/dnnl/dnnl_types.h"
33#include "utils/wrapper.hpp"
34
// Shared memory format tag strings. Only declarations live in this header;
// the definitions (and therefore the exact string values) are provided by
// another translation unit.
// NOTE(review): the names suggest plain one-dimensional ("x"),
// channels-first ("abx") and channels-last ("axb") layouts plus the special
// "any" / "undef" tags -- confirm against the definitions.
namespace tag {
extern const char *x;
extern const char *abx;
extern const char *axb;
extern const char *any;
extern const char *undef;
} // namespace tag
42
// Direction of a problem, encoded as a set of bit flags. The FLAG_* values
// occupy distinct bits; every other enumerator is a combination of them.
enum dir_t {
    DIR_UNDEF = 0,
    // Single-bit flags.
    FLAG_DAT = 1 << 0,
    FLAG_WEI = 1 << 1,
    FLAG_BIA = 1 << 2,
    FLAG_FWD = 1 << 5,
    FLAG_BWD = 1 << 6,
    FLAG_INF = 1 << 7,
    // Forward combinations.
    FWD_D = FLAG_FWD | FLAG_DAT,
    FWD_I = FLAG_FWD | FLAG_DAT | FLAG_INF,
    FWD_B = FLAG_FWD | FLAG_DAT | FLAG_BIA,
    // Backward combinations.
    BWD_D = FLAG_BWD | FLAG_DAT,
    BWD_DW = FLAG_BWD | FLAG_DAT | FLAG_WEI,
    BWD_W = FLAG_BWD | FLAG_WEI,
    BWD_WB = FLAG_BWD | FLAG_WEI | FLAG_BIA,
};
// Converts the string form of a direction into a dir_t value.
// NOTE(review): accepted spellings and the result for unknown input are
// decided by the implementation, which is not visible from this header.
dir_t str2dir(const char *str);

/* TODO: merge prop and dir_t (in favor of prop) */
// Returns the string form of a library propagation kind.
const char *prop2str(dnnl_prop_kind_t prop);
// Maps a dir_t value onto a library propagation kind.
dnnl_prop_kind_t prop2prop_kind(dir_t dir);

// Stream output for directions, library data types and engine kinds.
std::ostream &operator<<(std::ostream &s, dir_t dir);
std::ostream &operator<<(std::ostream &s, dnnl_data_type_t dt);
std::ostream &operator<<(std::ostream &s, dnnl_engine_kind_t ek);
// Writes the elements of `v` to `s`, separated by ':' (for example "1:2:3").
// Each element is written through its own operator<<. An empty vector
// writes nothing. Returns `s` to allow chaining.
template <typename T>
std::ostream &operator<<(std::ostream &s, const std::vector<T> &v) {
    // Reading v[0] from an empty vector is undefined behavior, so bail out
    // before touching any element.
    if (v.empty()) return s;

    s << v[0];
    for (size_t d = 1; d < v.size(); ++d)
        s << ":" << v[d];
    return s;
}
75
// Identifiers for the different kinds of data a problem works with. Values
// are consecutive starting from SRC = 0, so DAT_TOTAL, which must stay last,
// equals the number of kinds.
// NOTE(review): the section comments below come from the original source and
// name the drivers that use the entries which follow them.
enum data_kind_t {
    SRC = 0,
    WEI,
    BIA,
    DST,
    ACC,
    // bnorm, lnorm
    SRC_1,
    MEAN,
    VAR,
    SC,
    SH,
    // rnn
    DST_ITER,
    DST_ITER_C,
    AUGRU_ATTENTION,
    SRC_ITER,
    SRC_ITER_C,
    WEI_ITER,
    WEI_PEEPHOLE,
    WEI_PROJECTION,

    // Number of entries above; not a data kind itself.
    DAT_TOTAL,
};
// Returns the string form of a data kind.
const char *data_kind2str(data_kind_t kind);
101
102struct attr_t {
103 // policy_t defines the way entity values will be applied to a tensor
104 enum policy_t {
105 COMMON = 0, // single value for each point in a tensor
106 // apply a single value per...
107 PER_OC, // channel (dims[1]) point
108 PER_DIM_0, // ... dims[0] point.
109 PER_DIM_1, // ... dims[1] point.
110 PER_DIM_01, // ... unique combination of dims[0] and dims[1] points.
111 PER_DIM_2, // ... dims[2] point.
112 PER_DIM_023, // ... combination of dims[0], dims[2], dims[3] points.
113 PER_DIM_23, // ... combination of dims[2] and dims[3] points.
114 PER_DIM_03, // ... combination of dims[0] and dims[3] points.
115 PER_DIM_3, // ... dims[3] point.
116 PER_TENSOR, // ... point in the tensor.
117 POLICY_TOTAL // guard
118 };
119
120 static policy_t str2policy(const std::string &str);
121 static const char *policy2str(policy_t policy);
122 static int get_default_mask(policy_t policy, int arg = DNNL_ARG_DST);
123
124 struct scale_t {
125 scale_t(policy_t apolicy = COMMON, float ascale = 1.,
126 bool aruntime = false)
127 : policy(apolicy), scale(ascale), runtime(aruntime) {}
128
129 int from_str(const std::string &s);
130
131 bool is_def() const {
132 return policy == COMMON && scale == 1. && runtime == false;
133 }
134
135 policy_t policy = COMMON;
136 float scale = 1.;
137 bool runtime = false;
138 };
139
140 struct zero_points_t {
141 struct entry_t {
142 entry_t(policy_t apolicy = COMMON, int avalue = 0,
143 bool aruntime = false)
144 : policy(apolicy), value(avalue), runtime(aruntime) {}
145
146 entry_t(const entry_t &other)
147 : policy(other.policy)
148 , value(other.value)
149 , runtime(other.runtime) {}
150
151 bool is_def() const {
152 return policy == COMMON && value == 0 && runtime == false;
153 }
154
155 policy_t policy = COMMON;
156 int value = 0;
157 bool runtime = false;
158 };
159
160 int from_str(const std::string &s);
161
162 int operator[](int arg) const { return get(arg).value; }
163 bool runtime(int arg) const { return get(arg).runtime; }
164
165 bool is_def(int arg) const { return get(arg).is_def(); }
166 bool is_def() const { return points.empty(); }
167
168 void set(int arg, policy_t policy, int value, bool runtime) {
169 set(arg, entry_t(policy, value, runtime));
170 }
171 void set(int arg, const entry_t &entry) {
172 if (!entry.is_def()) points[arg] = entry;
173 }
174 entry_t get(int arg) const {
175 const auto it = points.find(arg);
176 return it == points.end() ? entry_t() : it->second;
177 }
178
179 std::unordered_map<int, entry_t>::const_iterator begin() const {
180 return points.begin();
181 }
182 std::unordered_map<int, entry_t>::const_iterator end() const {
183 return points.end();
184 }
185
186 zero_points_t() : points() {} // needed for debug icc190 build;
187 std::unordered_map<int, entry_t> points;
188 };
189
190 struct arg_scales_t {
191 void set(int arg, scale_t scale) { scales[arg] = scale; }
192
193 scale_t get(int arg) const {
194 const auto &s = scales.find(arg);
195 return s == scales.end() ? scale_t() : s->second;
196 }
197
198 bool is_def(int arg) const { return get(arg).is_def(); }
199 bool is_def() const {
200 bool def = true;
201 for (const auto &e : scales) {
202 def = def && is_def(e.first);
203 }
204 return def;
205 }
206 int from_str(const std::string &s);
207
208 arg_scales_t() : scales() {} // needed for debug icc190 build;
209
210 std::map<int, scale_t> scales;
211 };
212
213 struct post_ops_t {
214 enum kind_t {
215 // sum
216 SUM,
217 // depthwise convolution
218 DW,
219 DW_K3S1P1,
220 DW_K3S2P1,
221 // eltwise
222 ELTWISE_START, // a guard to check kind is eltwise
223 ABS,
224 CLIP,
225 CLIP_V2,
226 CLIP_V2_DST,
227 ELU,
228 ELU_DST,
229 EXP,
230 EXP_DST,
231 GELU_ERF,
232 GELU_TANH,
233 HARDSIGMOID,
234 HARDSWISH,
235 LINEAR,
236 LOG,
237 LOGISTIC,
238 LOGISTIC_DST,
239 MISH,
240 POW,
241 RELU,
242 RELU_DST,
243 ROUND,
244 SQRT,
245 SQRT_DST,
246 SQUARE,
247 SRELU,
248 SWISH,
249 TANH,
250 TANH_DST,
251 ELTWISE_END, // a guard to check kind is eltwise
252 // binary
253 BINARY_START, // a guard to check kind is binary
254 ADD,
255 DIV,
256 EQ,
257 GE,
258 GT,
259 LE,
260 LT,
261 MAX,
262 MIN,
263 MUL,
264 NE,
265 SUB,
266 BINARY_END, // a guard to check kind is binary
267 // prelu
268 PRELU,
269 // guard entry
270 KIND_TOTAL
271 };
272 static kind_t str2kind(const std::string &str);
273 static const char *kind2str(kind_t kind);
274 static dnnl_alg_kind_t kind2dnnl_kind(kind_t kind);
275
276 struct entry_t {
277 entry_t(kind_t akind) : kind(akind) {
278 if (is_sum_kind()) {
279 } else if (is_eltwise_kind()) {
280 eltwise.alg = kind2dnnl_kind(kind);
281 } else if (is_convolution_kind()) {
282 convolution.src_scale = scale_t();
283 convolution.wei_scale = scale_t();
284 convolution.dst_scale = scale_t();
285 if (kind != DW) {
286 convolution.kernel = 3;
287 convolution.stride = kind == DW_K3S1P1 ? 1 : 2;
288 convolution.padding = 1;
289 }
290 } else if (is_binary_kind()) {
291 binary.alg = kind2dnnl_kind(kind);
292 }
293 }
294
295 kind_t kind;
296 struct {
297 float scale = 1.f;
298 int32_t zero_point = 0;
299 dnnl_data_type_t dt = dnnl_data_type_undef;
300 } sum;
301 struct {
302 dnnl_alg_kind_t alg = dnnl_alg_kind_undef;
303 float alpha = 0.f;
304 float beta = 0.f;
305 } eltwise;
306 struct {
307 int kernel = 0;
308 int stride = 0;
309 int padding = 0;
310 dnnl_data_type_t dst_dt = dnnl_f32;
311 scale_t src_scale;
312 scale_t wei_scale;
313 scale_t dst_scale;
314 } convolution;
315 struct {
316 dnnl_alg_kind_t alg = dnnl_alg_kind_undef;
317 dnnl_data_type_t src1_dt = dnnl_data_type_undef;
318 policy_t policy = policy_t::COMMON;
319 std::string tag = tag::any;
320 } binary;
321 struct {
322 policy_t policy = policy_t::COMMON;
323 } prelu;
324
325 bool is_sum_kind() const;
326 bool is_convolution_kind() const;
327 bool is_eltwise_kind() const;
328 bool is_binary_kind() const;
329 bool is_prelu_kind() const;
330 };
331
332 post_ops_t() : entry() {}
333
334 int from_str(const std::string &s);
335
336 int len() const { return (int)entry.size(); }
337 bool is_def() const { return len() == 0; }
338
339 int find(kind_t kind, int start = 0, int stop = -1) const;
340 int eltwise_index() const;
341 int convolution_index() const;
342 int binary_index() const;
343 int prelu_index() const;
344
345 std::vector<std::pair<int, int>> get_po_masks() const;
346
347 std::vector<entry_t> entry;
348 };
349
350 attr_t()
351 : scratchpad_mode(dnnl_scratchpad_mode_library)
352 , fpmath_mode(dnnl_fpmath_mode_strict) {}
353
354 template <typename First, typename... Rest>
355 void insert(const First &first, const Rest &... rest) {
356 this->insert(first);
357 if (sizeof...(rest) > 0) this->insert(rest...);
358 }
359
360 void insert(const arg_scales_t &as) { this->scales = as; }
361 void insert(const zero_points_t &zp) { this->zero_points = zp; }
362 void insert(const post_ops_t &po) { this->post_ops = po; }
363 void insert(dnnl_scratchpad_mode_t sm) { this->scratchpad_mode = sm; }
364 void insert(dnnl_fpmath_mode_t fpm) { this->fpmath_mode = fpm; }
365
366 arg_scales_t scales;
367 zero_points_t zero_points;
368 post_ops_t post_ops;
369 dnnl_scratchpad_mode_t scratchpad_mode;
370 dnnl_fpmath_mode_t fpmath_mode;
371
372 bool is_def(bool skip_fpmath = false) const;
373};
374
// Holds the CPU ISA hint selected for a run and converts it to and from its
// string form.
struct isa_hints_t {
    enum cpu_hints_t {
        // If DNNL_CPU_ISA_HINTS is set then use hints from there
        // Otherwise no hints
        none = 0x0,
        // No CPU ISA specific hints
        // Will override DNNL_CPU_ISA_HINTS if that is available too
        no_hints = 0x1,
        // Use prefer_ymm CPU ISA hint
        // Will override DNNL_CPU_ISA_HINTS if that is available too
        prefer_ymm = 0x2,
    };

    cpu_hints_t hints_;
    isa_hints_t(cpu_hints_t hints) : hints_(hints) {}

    // Returns the stored hint. Const-qualified so that it is callable on
    // const objects and through const references.
    cpu_hints_t get() const { return hints_; }

    // Returns the string form of the hint held by `isa_hints`. A value
    // outside the enum asserts and yields "unknown_hint".
    static std::string hints2str(const isa_hints_t &isa_hints) {
        switch (isa_hints.hints_) {
            case none: return "none";
            case no_hints: return "no_hints";
            case prefer_ymm: return "prefer_ymm";
            default: assert(!"unknown hint"); return "unknown_hint";
        }
    }

    // Parses a hint name, ignoring case. Anything other than "prefer_ymm"
    // or "no_hints" -- including a null pointer -- maps to `none`.
    static isa_hints_t str2hints(const char *str) {
        // strcasecmp must not be handed a null pointer.
        if (str == nullptr) return isa_hints_t(none);

        if (strcasecmp(str, "prefer_ymm") == 0) return isa_hints_t(prefer_ymm);
        if (strcasecmp(str, "no_hints") == 0) return isa_hints_t(no_hints);

        return isa_hints_t(none);
    }
};
413
// Shorthand so that code outside attr_t can name the policy type directly.
using policy_t = attr_t::policy_t;

// Stream output for the attribute types and the library mode enums.
// NOTE(review): the exact text produced is defined by the implementations,
// which are not visible from this header.
std::ostream &operator<<(std::ostream &s, const policy_t &policy);
std::ostream &operator<<(std::ostream &s, const attr_t::scale_t &scale);
std::ostream &operator<<(
        std::ostream &s, const attr_t::zero_points_t &zero_points);
std::ostream &operator<<(std::ostream &s, const attr_t::arg_scales_t &scales);
std::ostream &operator<<(std::ostream &s, const attr_t::post_ops_t::kind_t &k);
std::ostream &operator<<(std::ostream &s, const attr_t::post_ops_t &post_ops);
std::ostream &operator<<(std::ostream &s, dnnl_scratchpad_mode_t sm);
std::ostream &operator<<(std::ostream &s, dnnl_fpmath_mode_t fm);
std::ostream &operator<<(std::ostream &s, const attr_t &attr);
426
427// A container for additional data and info, not available from user's input at
428// parse time, but which are required to create the library attributes.
429struct attr_args_t {
430 struct entry_t {
431 entry_t(const void *vals = NULL, int64_t count = 1, int mask = -1,
432 bool runtime = false)
433 : vals(vals), count(count), mask(mask), runtime(runtime) {}
434
435 bool is_def() const {
436 return vals == NULL && count == 1 && mask == -1 && runtime == false;
437 }
438
439 int64_t get_count(policy_t policy) const {
440 return (policy == policy_t::COMMON || runtime) ? 1 : count;
441 }
442
443 int get_mask(policy_t policy) const {
444 return mask == -1 ? attr_t::get_default_mask(policy) : mask;
445 }
446
447 const float *get_float_ptr() const {
448 return runtime ? &DNNL_RUNTIME_F32_VAL
449 : static_cast<const float *>(vals);
450 }
451
452 const void *vals = NULL;
453 int64_t count = 1;
454 int mask = -1;
455 bool runtime = false;
456 };
457
458 struct dw_t {
459 dnnl_data_type_t wei_dt = dnnl_data_type_undef;
460 dnnl_data_type_t bia_dt = dnnl_data_type_undef;
461 };
462
463 attr_args_t() = default;
464
465 void prepare_output_scales(
466 const attr_t &attr, const void *vals, int64_t count, int mask = -1);
467
468 void prepare_scales(const attr_t &attr, int arg, const void *vals,
469 int64_t count, int mask = -1);
470
471 int prepare_post_ops_mds(
472 const attr_t &attr, int ndims, const dnnl_dims_t dims);
473
474 void prepare_dw_post_op(const attr_t &attr, dnnl_data_type_t wei_dt,
475 dnnl_data_type_t bia_dt);
476
477 entry_t get(int arg) const {
478 const auto it = entries.find(arg);
479 return it == entries.end() ? entry_t() : it->second;
480 }
481
482 dnnl_memory_desc_t get_md(int arg) const {
483 const auto it = mds.find(arg);
484 return it == mds.end() ? nullptr : (dnnl_memory_desc_t)it->second;
485 }
486
487 dnnl_data_type_t get_dw_arg(int arg) const {
488 if (arg == DNNL_ARG_WEIGHTS)
489 return dw_entry.wei_dt;
490 else if (arg == DNNL_ARG_BIAS)
491 return dw_entry.bia_dt;
492 else {
493 assert(!"unsupported_argument");
494 return dnnl_data_type_undef;
495 }
496 }
497
498private:
499 void insert(
500 int arg, const void *vals, int64_t count, int mask, bool runtime) {
501 entries.insert(
502 std::make_pair(arg, entry_t(vals, count, mask, runtime)));
503 }
504
505 std::map<int, entry_t> entries;
506 std::map<int, benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t>> mds;
507 dw_t dw_entry; // only single dw fusion is supported
508};
509
// Writes the global parameters to `s` and returns the stream.
// NOTE(review): which parameters are written, and in what format, is decided
// by the implementation -- not visible from this header.
std::ostream &dump_global_params(std::ostream &s);
511
// Validates a tag/meta-tag.
// NOTE(review): the int result is presumably a status code -- confirm the
// success value against the implementation.
int check_tag(const std::string &tag_, bool check_enum_tags_only = false);

// Validates a tag in abc notation.
int check_abc_tag(const std::string &tag, bool check_enum_tags_only = false);

// Removes extra dimensions from a tag according to ndims.
std::string trim_tag(const std::string &tag, int ndims);
// Removes extra dimensions from a tag according to mask. The `ndims` version
// is a special case of the `mask` version, assuming that the first `ndims`
// dimensions of the mask are non-zero.
std::string trim_tag_by_mask(const std::string &tag, int mask);

// Converts a tag/meta-tag to abc notation.
// NOTE(review): `ndims = -1` presumably means "number of dimensions not
// specified" -- confirm against the implementation.
std::string normalize_tag(const std::string &tag, int ndims = -1);
527
// Creates a library primitive attribute object from the benchmark-side
// attribute description and the additional arguments gathered for it.
// NOTE(review): ownership of the returned handle is not visible from this
// header -- confirm who is expected to destroy it.
dnnl_primitive_attr_t create_dnnl_attr(
        const attr_t &attr, const attr_args_t &attr_args);

// Convert the string form of an engine kind, scratchpad mode or fpmath mode
// into the corresponding library enum value.
// NOTE(review): the behavior for unrecognized strings is decided by the
// implementations.
dnnl_engine_kind_t str2engine_kind(const char *str);
dnnl_scratchpad_mode_t str2scratchpad_mode(const char *str);
dnnl_fpmath_mode_t str2fpmath_mode(const char *str);
534
// Reference-computation helpers. `d` / `val` are passed by non-const
// reference, so the helpers can update them in place.
// NOTE(review): the descriptions below are inferred from names and
// signatures only; confirm the exact math against the implementations.

// Presumably applies the scale registered in `attr` for `arg` to `d`, using
// `scales` at index `c`; `opposite_scale` presumably inverts the operation.
void maybe_scale(const attr_t &attr, float &d, const float *scales, int64_t c,
        int arg, bool opposite_scale = false);
// Presumably applies the zero point registered in `attr` for `arg` to `d`,
// using `zero_points` at index `c`.
void maybe_zero_point(const attr_t &attr, float &d, const int32_t *zero_points,
        int64_t c, int arg, bool opposite_zero_point = false);
// Forward / backward reference values of the eltwise operation `kind` with
// parameters `alpha` and `beta`.
float compute_eltwise_fwd(
        attr_t::post_ops_t::kind_t kind, float src, float alpha, float beta);
float compute_eltwise_bwd(attr_t::post_ops_t::kind_t kind, float d_dst,
        float src, float alpha, float beta);
// Reference value of the binary operation `kind` on `src0` and `src1`.
float compute_binary(attr_t::post_ops_t::kind_t kind, float src0, float src1);
// Presumably applies the post-ops of `attr` to `val`; `sum_val` and
// `v_po_vals` supply the additional inputs some post-ops need.
void maybe_post_ops(const attr_t &attr, float &val, float sum_val,
        const std::vector<float> &v_po_vals);
546inline void maybe_post_ops(
547 const attr_t &attr, float &val, float sum_val = 0.f) {
548 maybe_post_ops(attr, val, sum_val, std::vector<float>());
549}
550
// When using the fast-ref-gpu option, the reference expects everything to be
// in the f32 data type and no additional memories coming from runtime
// attributes. That's why we update all data types to f32 and remove all
// runtime arguments to make them constant when possible.
// `attr` is updated in place; `new_dt` is the data type to switch to (f32 by
// default).
void update_cpu_ref_attrs(attr_t &attr, dnnl_data_type_t new_dt = dnnl_f32);
556#endif
557