1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_HPP |
18 | #define COMMON_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <float.h> |
22 | #include <math.h> |
23 | #include <stddef.h> |
24 | #include <stdio.h> |
25 | #include <stdlib.h> |
26 | #include <string.h> |
27 | |
28 | #include <bitset> |
29 | #include <cinttypes> |
30 | #include <functional> |
31 | #include <string> |
32 | #include <vector> |
33 | |
34 | #include "src/common/z_magic.hpp" |
35 | |
36 | #include "utils/timer.hpp" |
37 | |
// Small arithmetic helper macros. Every argument is parenthesized, but
// arguments may be evaluated more than once, so avoid side effects in them.
#define ABS(a) (((a) > 0) ? (a) : -(a))

#define MIN2(a, b) (((a) < (b)) ? (a) : (b))
#define MAX2(a, b) (((a) > (b)) ? (a) : (b))

#define MIN3(a, b, c) MIN2((a), MIN2((b), (c)))
#define MAX3(a, b, c) MAX2((a), MAX2((b), (c)))

// Logical implication `cause => effect` as a bool: true unless `cause` holds
// and `effect` does not. `effect` is evaluated only when `cause` is true.
#define IMPLICATION(cause, effect) (!(cause) ? true : !!(effect))
47 | |
// On Windows with a compiler that is not GCC-compatible, map the POSIX/GNU
// spellings used by benchdnn onto their Microsoft equivalents.
#if !defined(__GNUC__) && defined(_WIN32)
#define strcasecmp _stricmp
#define strncasecmp _strnicmp
#define __PRETTY_FUNCTION__ __FUNCSIG__
#endif

// For plain MSVC (neither clang nor the Intel compiler), `collapse(x)` expands
// to nothing. NOTE(review): presumably this drops the OpenMP `collapse`
// clause, which MSVC's OpenMP support lacks -- confirm.
#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) && !defined(__clang__)
#define collapse(x)
#endif
57 | |
// printf-style conversion specification for int64_t values.
#define IFMT "%" PRId64

// Status codes: OK signals success, FAIL a failure.
#define OK 0
#define FAIL 1

// Severity levels understood by SAFE().
enum {
    CRIT = 1, // report the failure on stderr, then exit(1)
    WARN = 2, // report the failure on stderr, then return the status
};
64 | |
// Evaluates `f` once as an int status. If the status is not OK:
//   - when severity `s` is CRIT or WARN, the failing expression and status
//     are reported on stderr and all output streams are flushed;
//   - CRIT then terminates the process with exit(1);
//   - otherwise (WARN or any other severity) the status is returned from the
//     enclosing function.
// `s` is parenthesized so compound severity expressions (e.g. `flags & mask`
// or `cond ? CRIT : WARN`) are compared as a whole rather than being split
// by operator precedence.
#define SAFE(f, s)                                                        \
    do {                                                                  \
        int status__ = (f);                                               \
        if (status__ != OK) {                                             \
            if ((s) == CRIT || (s) == WARN) {                             \
                fprintf(stderr, "@@@ error [%s:%d]: '%s' -> %d\n",        \
                        __PRETTY_FUNCTION__, __LINE__, STRINGIFY(f),      \
                        status__);                                        \
                fflush(0);                                                \
                if ((s) == CRIT) exit(1);                                 \
            }                                                             \
            return status__;                                              \
        }                                                                 \
    } while (0)
79 | |
// Evaluates `f` once as an int status. Any status other than OK is reported
// on stderr (with the failing expression), output streams are flushed, and
// the process terminates with exit(1). The macro never returns a value, so
// it can be used where returning the status is not possible (e.g. in void
// functions).
#define SAFE_V(f)                                                           \
    do {                                                                    \
        const int status__ = (f);                                           \
        if (status__ == OK) break;                                          \
        fprintf(stderr, "@@@ error [%s:%d]: '%s' -> %d\n",                  \
                __PRETTY_FUNCTION__, __LINE__, STRINGIFY(f), status__);     \
        fflush(0);                                                          \
        exit(1);                                                            \
    } while (0)
90 | |
// Process-wide benchdnn settings; their definitions live in another
// translation unit.
extern int verbose; // threshold compared by BENCHDNN_PRINT (`verbose >= v`)
extern bool canonical, mem_check, attr_same_pd_check;
// skip_impl: an empty string ("") means nothing is skipped.
extern std::string skip_impl, driver_name;
97 | |
// Prints `fmt` with the trailing arguments to stdout, then flushes all output
// streams, but only when the global `verbose` level is at least `v`.
// `v` is parenthesized so compound level expressions (e.g. `a & b`) are
// compared as a whole instead of being split by operator precedence.
// At least one argument must follow `fmt`: the expansion forwards
// `__VA_ARGS__` directly.
// Debug variant that also tags the call site:
//   printf("[%d][%s:%d]" fmt, v, __func__, __LINE__, __VA_ARGS__);
#define BENCHDNN_PRINT(v, fmt, ...)                                       \
    do {                                                                  \
        if (verbose >= (v)) {                                             \
            printf(fmt, __VA_ARGS__);                                     \
            fflush(0);                                                    \
        }                                                                 \
    } while (0)
106 | |
// Expands, inside the body of class `T`, to deleted copy-assignment and
// copy-constructor declarations. Because copy operations become
// user-declared, `T` also gets no implicit default constructor and no
// implicit move operations.
#define BENCHDNN_DISALLOW_COPY_AND_ASSIGN(T) \
    T &operator=(const T &) = delete;        \
    T(const T &) = delete;
110 | |
// A benchmark mode is a set of 8 flag bits.
using bench_mode_t = std::bitset<8>;

// Pre-defined modes.
extern bench_mode_t RUN;
extern bench_mode_t CORR;
extern bench_mode_t PERF;
extern bench_mode_t LIST;
extern bench_mode_t PROF;
// Mode selected by the user.
extern bench_mode_t bench_mode;

// NOTE(review): presumably reports whether `user_mode` is enabled in the
// current `bench_mode` -- confirm against the definition.
bool is_bench_mode(bench_mode_t user_mode);

/* perf */
extern double max_ms_per_prb; // maximum time spent per prb, in ms
extern int min_times_per_prb; // minimum number of runs per prb
extern int fix_times_per_prb; // if non-zero, run each prb exactly that many times

extern bool fast_ref_gpu;
extern bool allow_enum_tags_only;
extern int test_start;
125 | |
126 | /* global stats */ |
127 | struct stat_t { |
128 | int tests; |
129 | int passed; |
130 | int failed; |
131 | int skipped; |
132 | int mistrusted; |
133 | int unimplemented; |
134 | int invalid_arguments; |
135 | int listed; |
136 | double ms[timer::timer_t::mode_t::n_modes]; |
137 | }; |
138 | extern stat_t benchdnn_stat; |
139 | |
/* result structure */
// Possible states of a problem result. Enumerator values are spelled out
// explicitly.
enum res_state_t {
    UNTESTED = 0,
    PASSED = 1,
    SKIPPED = 2,
    MISTRUSTED = 3,
    UNIMPLEMENTED = 4,
    INVALID_ARGUMENTS = 5,
    FAILED = 6,
    LISTED = 7,
    EXECUTED = 8,
};
// Maps a result state to a C string (presumably its printable name).
const char *state2str(res_state_t state);
153 | |
// Reasons a problem can be skipped. Enumerator values are spelled out
// explicitly.
enum skip_reason_t {
    SKIP_UNKNOWN = 0,
    CASE_NOT_SUPPORTED = 1,
    DATA_TYPE_NOT_SUPPORTED = 2,
    INVALID_CASE = 3,
    NOT_ENOUGH_RAM = 4,
    SKIP_IMPL_HIT = 5,
    SKIP_START = 6,
};
// Maps a skip reason to a C string (presumably its printable name).
const char *skip_reason2str(skip_reason_t skip_reason);
164 | |
165 | struct res_t { |
166 | res_state_t state; |
167 | size_t errors, total; |
168 | timer::timer_map_t timer_map; |
169 | std::string impl_name; |
170 | skip_reason_t reason; |
171 | size_t ibytes, obytes; |
172 | }; |
173 | |
174 | void parse_result(res_t &res, const char *pstr); |
175 | |
/* misc */
// Initializes the floating-point mode (see the definition for details).
void init_fp_mode();

// Allocation helpers taking an explicit alignment; memory obtained from
// zmalloc() is presumably meant to be released with zfree().
void *zmalloc(size_t size, size_t align);
void zfree(void *ptr);

// Conversions between C strings and bool.
bool str2bool(const char *str);
const char *bool2str(bool value);

// TODO: two separate matching helpers exist here; consider unifying them.
bool match_regex(const char *str, const char *pattern);
bool maybe_skip(const std::string &impl_str);

// Driver entry point signature, as invoked by batch().
using bench_f = int (*)(int argc, char **argv);
int batch(const char *fname, bench_f bench);

// Returns 1 with the given probability.
int flip_coin(ptrdiff_t seed, float probability);

// Integer helpers (div_up presumably rounds the quotient up; next_pow2
// presumably rounds `a` up to a power of two -- confirm).
int64_t div_up(const int64_t a, const int64_t b);
int64_t next_pow2(int64_t a);
int mxcsr_cvt(float f);

// Sets '0' across the `size` bytes starting at `arr`.
void array_set(char *arr, size_t size);

// Wrapper around dnnl_sgemm.
//   layout == 'F': column-major
//   layout == 'C': row-major
void gemm(const char *layout, const char *transa, const char *transb, int64_t m,
        int64_t n, int64_t k, const float alpha, const float *a,
        const int64_t lda, const float *b, const int64_t ldb, const float beta,
        float *c, const int64_t ldc);

// Helpers for the spatial (d, h, w) parts of problem descriptors.
// NOTE(review): exact normalization/printing rules live in the definitions.
int sanitize_desc(int &ndims, std::vector<std::reference_wrapper<int64_t>> d,
        std::vector<std::reference_wrapper<int64_t>> h,
        std::vector<std::reference_wrapper<int64_t>> w,
        const std::vector<int64_t> &def_values, bool must_have_spatial = false);

void print_dhw(bool &print_d, bool &print_h, bool &print_w, int ndims,
        const std::vector<int64_t> &d, const std::vector<int64_t> &h,
        const std::vector<int64_t> &w);
218 | |
219 | #endif |
220 | |