1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_HPP
18#define COMMON_HPP
19
20#include <assert.h>
21#include <float.h>
22#include <math.h>
23#include <stddef.h>
24#include <stdio.h>
25#include <stdlib.h>
26#include <string.h>
27
28#include <bitset>
29#include <cinttypes>
30#include <functional>
31#include <string>
32#include <vector>
33
34#include "src/common/z_magic.hpp"
35
36#include "utils/timer.hpp"
37
/* Absolute value of `a`. As a function-like macro, the argument is evaluated
 * more than once, so avoid arguments with side effects. */
#define ABS(a) ((a) > 0 ? (a) : (-(a)))

/* Minimum / maximum of two values; arguments may be evaluated twice. */
#define MIN2(a, b) ((a) < (b) ? (a) : (b))
#define MAX2(a, b) ((a) > (b) ? (a) : (b))

/* Minimum / maximum of three values, composed from MIN2 / MAX2; arguments may
 * be evaluated more than once. */
#define MIN3(a, b, c) MIN2(a, MIN2(b, c))
#define MAX3(a, b, c) MAX2(a, MAX2(b, c))

/* Logical implication: true unless `cause` holds and `effect` does not. */
#define IMPLICATION(cause, effect) (!(cause) || !!(effect))
47
/* On Windows with a non-GNU compiler, map the POSIX / GCC names used in this
 * code onto their MSVC equivalents. */
#if defined(_WIN32) && !defined(__GNUC__)
#define strncasecmp _strnicmp
#define strcasecmp _stricmp
#define __PRETTY_FUNCTION__ __FUNCSIG__
#endif

/* With MSVC proper (not clang-cl and not the Intel compiler), `collapse(x)`
 * expands to nothing. NOTE(review): presumably this drops OpenMP `collapse`
 * clauses that MSVC's OpenMP implementation does not accept -- confirm. */
#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
#define collapse(x)
#endif
57
/* printf conversion specification for int64_t values. */
#define IFMT "%" PRId64

/* Generic success / failure status codes; SAFE and SAFE_V treat any status
 * other than OK as an error. */
#define OK 0
#define FAIL 1

/* Severity levels understood by SAFE: CRIT reports the error and exits the
 * process, WARN reports the error and returns the failing status. */
enum { CRIT = 1, WARN = 2 };
64
/* Evaluates `f` (an expression yielding an int status) once. If the status is
 * not OK, the enclosing function returns that status; the severity `s`
 * selects the reporting:
 *   - CRIT: print "@@@ error [...]" to stderr, flush all streams, exit(1);
 *   - WARN: print the same message, flush, then return the status;
 *   - any other value: return the status silently.
 * The enclosing function must return a type constructible from int.
 * `s` is parenthesized at every use so that compound severity expressions
 * such as `cond ? WARN : 0` are compared as a whole (unparenthesized, the
 * `== CRIT` test could bind inside the conditional and exit unexpectedly).
 * `s` may be evaluated up to three times, so it should be side-effect free. */
#define SAFE(f, s) \
    do { \
        int status__ = (f); \
        if (status__ != OK) { \
            if ((s) == CRIT || (s) == WARN) { \
                fprintf(stderr, "@@@ error [%s:%d]: '%s' -> %d\n", \
                        __PRETTY_FUNCTION__, __LINE__, STRINGIFY(f), \
                        status__); \
                fflush(0); \
                if ((s) == CRIT) exit(1); \
            } \
            return status__; \
        } \
    } while (0)
79
/* Evaluates `f` (an expression yielding an int status) once. On OK, nothing
 * happens. Any other status is treated as fatal: an "@@@ error [...]" line
 * naming the enclosing function, the line, and the expression is written to
 * stderr, all streams are flushed, and the process terminates with exit(1).
 * The macro contains no return statement, so it also works inside functions
 * returning void. */
#define SAFE_V(f) \
    do { \
        int status__ = (f); \
        if (status__ == OK) break; \
        fprintf(stderr, "@@@ error [%s:%d]: '%s' -> %d\n", \
                __PRETTY_FUNCTION__, __LINE__, STRINGIFY(f), status__); \
        fflush(0); \
        exit(1); \
    } while (0)
90
/* Global settings defined in another translation unit. NOTE(review):
 * presumably populated from command-line options -- confirm in the driver. */
extern int verbose; /* verbosity level; BENCHDNN_PRINT compares against it */
extern bool canonical;
extern bool mem_check;
extern bool attr_same_pd_check;
extern std::string skip_impl; /* an empty string means skip nothing */
extern std::string driver_name;
97
/* Prints a printf-style message to stdout and flushes all streams, but only
 * when the global `verbose` level is at least `v`. At least one argument must
 * follow `fmt`, because the expansion forwards __VA_ARGS__ after a comma.
 * `v` is parenthesized so that a compound level expression (for example
 * `cond ? 1 : 2`) is compared as a whole instead of being split by the
 * `>=` operator's higher precedence. */
#define BENCHDNN_PRINT(v, fmt, ...) \
    do { \
        if (verbose >= (v)) { \
            printf(fmt, __VA_ARGS__); \
            /* printf("[%d][%s:%d]" fmt, v, __func__, __LINE__, __VA_ARGS__); */ \
            fflush(0); \
        } \
    } while (0)
106
/* Place inside the body of class `T` to make it non-copyable: both the
 * copy-assignment operator and the copy constructor are deleted. The
 * expansion already ends with a semicolon. */
#define BENCHDNN_DISALLOW_COPY_AND_ASSIGN(T) \
    T &operator=(const T &) = delete; \
    T(const T &) = delete;
110
/* Benchmark mode: a bitset of up to 8 mode flags. */
using bench_mode_t = std::bitset<8>;
extern bench_mode_t RUN, CORR, PERF, LIST, PROF; // pre-defined modes
extern bench_mode_t bench_mode; // user mode

/* NOTE(review): presumably reports whether `user_mode` is enabled in the
 * global `bench_mode`; the exact matching rule (subset vs. overlap) lives in
 * the implementation -- confirm there. */
bool is_bench_mode(bench_mode_t user_mode);

/* perf */
extern double max_ms_per_prb; /** maximum time spent per prb, in ms */
extern int min_times_per_prb; /** minimum number of runs per prb */
extern int fix_times_per_prb; /** if non-zero, run each prb exactly that many times */

extern bool fast_ref_gpu;
extern bool allow_enum_tags_only;
extern int test_start;
125
/* global stats */
/* Aggregate counters for a benchdnn run. The counter names mirror the
 * res_state_t outcomes. NOTE(review): the struct is kept an aggregate (no
 * default member initializers); its definition is expected to be
 * brace-initialized elsewhere -- confirm before adding initializers. */
struct stat_t {
    int tests;
    int passed;
    int failed;
    int skipped;
    int mistrusted;
    int unimplemented;
    int invalid_arguments;
    int listed;
    /* one entry per timer::timer_t mode; presumably accumulated milliseconds */
    double ms[timer::timer_t::mode_t::n_modes];
};
extern stat_t benchdnn_stat;
139
/* result structure */
/* Outcome of a test case. Enumerators are implicitly numbered in declaration
 * order starting from UNTESTED = 0, so do not reorder them. */
enum res_state_t {
    UNTESTED = 0,
    PASSED,
    SKIPPED,
    MISTRUSTED,
    UNIMPLEMENTED,
    INVALID_ARGUMENTS,
    FAILED,
    LISTED,
    EXECUTED,
};
/* Presumably returns a printable name for `state` -- confirm. */
const char *state2str(res_state_t state);
153
/* Reason a test case was skipped. Enumerators are implicitly numbered in
 * declaration order starting from SKIP_UNKNOWN = 0. */
enum skip_reason_t {
    SKIP_UNKNOWN = 0,
    CASE_NOT_SUPPORTED,
    DATA_TYPE_NOT_SUPPORTED,
    INVALID_CASE,
    NOT_ENOUGH_RAM,
    SKIP_IMPL_HIT,
    SKIP_START,
};
/* Presumably returns a printable name for `skip_reason` -- confirm. */
const char *skip_reason2str(skip_reason_t skip_reason);
164
/* Result record for a single test case. Kept an aggregate; member order
 * matters for brace initialization. */
struct res_t {
    res_state_t state;
    size_t errors, total; /* presumably mismatching vs. total compared points */
    timer::timer_map_t timer_map;
    std::string impl_name;
    skip_reason_t reason; /* presumably meaningful when state == SKIPPED */
    size_t ibytes, obytes; /* presumably input / output bytes -- confirm */
};
173
/* Handles the final result `res` of the problem described by `pstr`.
 * NOTE(review): despite the name, presumably reports the result and updates
 * global statistics rather than parsing input -- confirm in the
 * implementation. */
void parse_result(res_t &res, const char *pstr);

/* misc */
/* NOTE(review): presumably configures the floating-point environment
 * (e.g. denormal handling) -- confirm. */
void init_fp_mode();

/* Allocates `size` bytes with alignment `align`; presumably the result must
 * be released with zfree -- confirm. */
void *zmalloc(size_t size, size_t align);
void zfree(void *ptr);

/* Conversions between a C string and a bool. */
bool str2bool(const char *str);
const char *bool2str(bool value);

/* TODO: check whether both functions below are needed. */
bool match_regex(const char *str, const char *pattern);
/* Presumably checks `impl_str` against the global skip_impl setting. */
bool maybe_skip(const std::string &impl_str);

/* Driver entry point with an argc/argv-style interface. */
typedef int (*bench_f)(int argc, char **argv);
/* NOTE(review): presumably reads arguments from the file `fname` and feeds
 * them to `bench` -- confirm. */
int batch(const char *fname, bench_f bench);

/* returns 1 with given probability */
int flip_coin(ptrdiff_t seed, float probability);

/* Presumably ceil(a / b) for integers. */
int64_t div_up(const int64_t a, const int64_t b);
/* Presumably the next power of two related to `a` -- confirm rounding. */
int64_t next_pow2(int64_t a);
/* NOTE(review): presumably converts `f` to int using the current MXCSR
 * rounding mode -- confirm. */
int mxcsr_cvt(float f);

/* Sets the `size` bytes starting at `arr` (original note: "set '0' across
 * *arr:+size"; presumably to zero -- confirm). */
void array_set(char *arr, size_t size);
201
/* Wrapper around dnnl_sgemm (presumably computing
 * c = alpha * op(a) * op(b) + beta * c, as sgemm does).
 *   layout = 'F' - column major
 *   layout = 'C' - row major */
void gemm(const char *layout, const char *transa, const char *transb, int64_t m,
        int64_t n, int64_t k, const float alpha, const float *a,
        const int64_t lda, const float *b, const int64_t ldb, const float beta,
        float *c, const int64_t ldc);

/* Takes `ndims` by reference and the d/h/w sizes as references, so it can
 * update both in place; `def_values` presumably supplies defaults for
 * unspecified sizes. Returns an int status (presumably OK / FAIL). */
int sanitize_desc(int &ndims, std::vector<std::reference_wrapper<int64_t>> d,
        std::vector<std::reference_wrapper<int64_t>> h,
        std::vector<std::reference_wrapper<int64_t>> w,
        const std::vector<int64_t> &def_values, bool must_have_spatial = false);

/* Writes into print_d / print_h / print_w whether each spatial dimension
 * should be printed, based on `ndims` and the d/h/w sizes (presumably for the
 * problem-descriptor string -- confirm). */
void print_dhw(bool &print_d, bool &print_h, bool &print_w, int ndims,
        const std::vector<int64_t> &d, const std::vector<int64_t> &h,
        const std::vector<int64_t> &w);
218
219#endif
220