1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef BNORM_HPP |
18 | #define BNORM_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <limits.h> |
22 | #include <stdint.h> |
23 | |
24 | #include <iostream> |
25 | #include <string> |
26 | |
27 | #include "common.hpp" |
28 | #include "dnn_types.hpp" |
29 | #include "dnnl_common.hpp" |
30 | #include "dnnl_debug.hpp" |
31 | #include "utils/perf_report.hpp" |
32 | #include "utils/settings.hpp" |
33 | |
34 | #ifdef DNNL_EXPERIMENTAL |
35 | #include "src/common/experimental.hpp" |
36 | #endif |
37 | |
38 | namespace bnorm { |
39 | |
// Selector for the check algorithm used by the bnorm driver. The exact
// meaning of each value is defined by the driver implementation, which is
// not part of this header.
enum check_alg_t {
    ALG_0,
    ALG_1,
    ALG_2,
    ALG_AUTO // default value used by `settings_t`
};

// Conversions between `check_alg_t` and its textual form.
check_alg_t str2check_alg(const char *str);
const char *check_alg2str(check_alg_t alg);
43 | |
44 | using flags_t = unsigned; |
45 | const flags_t NONE = dnnl_normalization_flags_none; |
46 | const flags_t GLOB_STATS = dnnl_use_global_stats; |
47 | const flags_t USE_SCALE = dnnl_use_scale; |
48 | const flags_t USE_SHIFT = dnnl_use_shift; |
49 | const flags_t FUSE_NORM_RELU = dnnl_fuse_norm_relu; |
50 | const flags_t FUSE_NORM_ADD_RELU = dnnl_fuse_norm_add_relu; |
51 | flags_t str2flags(const char *str); |
52 | std::string flags2str(flags_t flags); |
53 | |
54 | struct desc_t { |
55 | int64_t mb, ic, id, ih, iw; |
56 | float eps; |
57 | std::string name; |
58 | int ndims; |
59 | |
60 | dims_t data_dims() const; |
61 | }; |
62 | int str2desc(desc_t *desc, const char *str); |
63 | std::ostream &operator<<(std::ostream &s, const desc_t &d); |
64 | |
65 | struct settings_t : public base_settings_t { |
66 | settings_t() = default; |
67 | |
68 | // ctor to save certain fields from resetting |
69 | settings_t(const char *perf_template) : settings_t() { |
70 | this->perf_template = perf_template; |
71 | } |
72 | |
73 | desc_t desc {}; |
74 | |
75 | std::vector<dir_t> dir {FWD_D}; |
76 | std::vector<dnnl_data_type_t> dt {dnnl_f32}; |
77 | std::vector<std::string> tag {tag::abx}; |
78 | std::vector<flags_t> flags {NONE}; |
79 | check_alg_t check_alg = ALG_AUTO; |
80 | bool debug_check_ws = false; |
81 | |
82 | const char *perf_template_csv() const { |
83 | static const std::string args = "%dir%,%dt%,%tag%,%flags%" ; |
84 | return perf_template_csv_base(args); |
85 | } |
86 | |
87 | void reset() { *this = settings_t(perf_template); } |
88 | }; |
89 | |
90 | struct prb_t : public desc_t { |
91 | prb_t(const desc_t &desc, int64_t mb, dir_t dir, dnnl_data_type_t dt, |
92 | const std::string &tag, flags_t flags, bool inplace, |
93 | const attr_t &attr, const thr_ctx_t &ctx_init, |
94 | const thr_ctx_t &ctx_exe, check_alg_t check_alg, |
95 | bool debug_check_ws) |
96 | : desc_t(desc) |
97 | , check_alg(check_alg) |
98 | , debug_check_ws(debug_check_ws) |
99 | , dir(dir) |
100 | , dt(dt) |
101 | , tag(tag) |
102 | , flags(flags) |
103 | , inplace(inplace) |
104 | , attr(attr) |
105 | , ctx_init(ctx_init) |
106 | , ctx_exe(ctx_exe) |
107 | , user_mb(mb) { |
108 | if (mb) this->mb = mb; |
109 | } |
110 | ~prb_t() {} |
111 | |
112 | check_alg_t check_alg; |
113 | bool debug_check_ws; |
114 | |
115 | dir_t dir; |
116 | dnnl_data_type_t dt; |
117 | std::string tag; |
118 | flags_t flags; |
119 | bool inplace; |
120 | attr_t attr; |
121 | const thr_ctx_t ctx_init, ctx_exe; |
122 | int64_t user_mb; |
123 | |
124 | bool need_ws() const { |
125 | return (flags & (FUSE_NORM_RELU | FUSE_NORM_ADD_RELU)) |
126 | && !(dir & FLAG_INF); |
127 | } |
128 | |
129 | bool use_sc() const { return flags & USE_SCALE; } |
130 | bool use_sh() const { return flags & USE_SHIFT; } |
131 | bool fuse_relu() const { |
132 | return flags & (FUSE_NORM_RELU | FUSE_NORM_ADD_RELU); |
133 | } |
134 | bool fuse_add_relu() const { return flags & FUSE_NORM_ADD_RELU; } |
135 | }; |
136 | std::ostream &operator<<(std::ostream &s, const prb_t &prb); |
137 | |
138 | struct perf_report_t : public base_perf_report_t { |
139 | perf_report_t(const prb_t *prb, const char *perf_template) |
140 | : base_perf_report_t(perf_template) |
141 | , p_(prb) |
142 | , tag_(normalize_tag(p_->tag, p_->ndims)) {} |
143 | |
144 | void dump_desc(std::ostream &s) const override { |
145 | s << static_cast<const desc_t &>(*p_); |
146 | } |
147 | |
148 | void dump_desc_csv(std::ostream &s) const override { |
149 | s << p_->mb << ',' << p_->ic << ',' << p_->id << ',' << p_->ih << ',' |
150 | << p_->iw << ',' << p_->eps; |
151 | } |
152 | |
153 | void dump_flags(std::ostream &s) const override { |
154 | s << flags2str(p_->flags); |
155 | } |
156 | |
157 | const attr_t *attr() const override { return &p_->attr; } |
158 | const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; } |
159 | const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; } |
160 | const int64_t *user_mb() const override { return &p_->user_mb; } |
161 | const std::string *name() const override { return &p_->name; } |
162 | const dir_t *dir() const override { return &p_->dir; } |
163 | const dnnl_data_type_t *dt() const override { return &p_->dt; } |
164 | const std::string *tag() const override { return &tag_; } |
165 | |
166 | private: |
167 | const prb_t *p_; |
168 | std::string tag_; |
169 | }; |
170 | |
171 | /* some extra control parameters which shouldn't be placed in prb_t */ |
172 | |
173 | inline size_t data_off(const prb_t *prb, int64_t mb, int64_t c, int64_t d, |
174 | int64_t h, int64_t w) { |
175 | return (((mb * prb->ic + c) * prb->id + d) * prb->ih + h) * prb->iw + w; |
176 | } |
177 | |
178 | void skip_unimplemented_prb(const prb_t *prb, res_t *res); |
179 | void skip_invalid_prb(const prb_t *prb, res_t *res); |
180 | void compute_ref(const prb_t *prb, const args_t &args, |
181 | dnnl_primitive_t prim_ref = nullptr); |
182 | |
183 | int doit(const prb_t *prb, res_t *res); |
184 | int bench(int argc, char **argv); |
185 | |
186 | } // namespace bnorm |
187 | |
188 | #endif |
189 | |