1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef BNORM_HPP
18#define BNORM_HPP
19
20#include <assert.h>
21#include <limits.h>
22#include <stdint.h>
23
24#include <iostream>
25#include <string>
26
27#include "common.hpp"
28#include "dnn_types.hpp"
29#include "dnnl_common.hpp"
30#include "dnnl_debug.hpp"
31#include "utils/perf_report.hpp"
32#include "utils/settings.hpp"
33
34#ifdef DNNL_EXPERIMENTAL
35#include "src/common/experimental.hpp"
36#endif
37
38namespace bnorm {
39
// Selector for the correctness-check algorithm (stored in settings_t and
// prb_t as `check_alg`). ALG_AUTO presumably lets the driver choose one of
// the concrete algorithms - TODO confirm against the driver sources.
// The explicit values match the implicit enumeration order 0..3.
enum check_alg_t {
    ALG_0 = 0,
    ALG_1 = 1,
    ALG_2 = 2,
    ALG_AUTO = 3,
};
// Conversions between check_alg_t and its textual form (defined elsewhere).
check_alg_t str2check_alg(const char *str);
const char *check_alg2str(check_alg_t alg);
43
44using flags_t = unsigned;
45const flags_t NONE = dnnl_normalization_flags_none;
46const flags_t GLOB_STATS = dnnl_use_global_stats;
47const flags_t USE_SCALE = dnnl_use_scale;
48const flags_t USE_SHIFT = dnnl_use_shift;
49const flags_t FUSE_NORM_RELU = dnnl_fuse_norm_relu;
50const flags_t FUSE_NORM_ADD_RELU = dnnl_fuse_norm_add_relu;
51flags_t str2flags(const char *str);
52std::string flags2str(flags_t flags);
53
54struct desc_t {
55 int64_t mb, ic, id, ih, iw;
56 float eps;
57 std::string name;
58 int ndims;
59
60 dims_t data_dims() const;
61};
62int str2desc(desc_t *desc, const char *str);
63std::ostream &operator<<(std::ostream &s, const desc_t &d);
64
65struct settings_t : public base_settings_t {
66 settings_t() = default;
67
68 // ctor to save certain fields from resetting
69 settings_t(const char *perf_template) : settings_t() {
70 this->perf_template = perf_template;
71 }
72
73 desc_t desc {};
74
75 std::vector<dir_t> dir {FWD_D};
76 std::vector<dnnl_data_type_t> dt {dnnl_f32};
77 std::vector<std::string> tag {tag::abx};
78 std::vector<flags_t> flags {NONE};
79 check_alg_t check_alg = ALG_AUTO;
80 bool debug_check_ws = false;
81
82 const char *perf_template_csv() const {
83 static const std::string args = "%dir%,%dt%,%tag%,%flags%";
84 return perf_template_csv_base(args);
85 }
86
87 void reset() { *this = settings_t(perf_template); }
88};
89
90struct prb_t : public desc_t {
91 prb_t(const desc_t &desc, int64_t mb, dir_t dir, dnnl_data_type_t dt,
92 const std::string &tag, flags_t flags, bool inplace,
93 const attr_t &attr, const thr_ctx_t &ctx_init,
94 const thr_ctx_t &ctx_exe, check_alg_t check_alg,
95 bool debug_check_ws)
96 : desc_t(desc)
97 , check_alg(check_alg)
98 , debug_check_ws(debug_check_ws)
99 , dir(dir)
100 , dt(dt)
101 , tag(tag)
102 , flags(flags)
103 , inplace(inplace)
104 , attr(attr)
105 , ctx_init(ctx_init)
106 , ctx_exe(ctx_exe)
107 , user_mb(mb) {
108 if (mb) this->mb = mb;
109 }
110 ~prb_t() {}
111
112 check_alg_t check_alg;
113 bool debug_check_ws;
114
115 dir_t dir;
116 dnnl_data_type_t dt;
117 std::string tag;
118 flags_t flags;
119 bool inplace;
120 attr_t attr;
121 const thr_ctx_t ctx_init, ctx_exe;
122 int64_t user_mb;
123
124 bool need_ws() const {
125 return (flags & (FUSE_NORM_RELU | FUSE_NORM_ADD_RELU))
126 && !(dir & FLAG_INF);
127 }
128
129 bool use_sc() const { return flags & USE_SCALE; }
130 bool use_sh() const { return flags & USE_SHIFT; }
131 bool fuse_relu() const {
132 return flags & (FUSE_NORM_RELU | FUSE_NORM_ADD_RELU);
133 }
134 bool fuse_add_relu() const { return flags & FUSE_NORM_ADD_RELU; }
135};
136std::ostream &operator<<(std::ostream &s, const prb_t &prb);
137
138struct perf_report_t : public base_perf_report_t {
139 perf_report_t(const prb_t *prb, const char *perf_template)
140 : base_perf_report_t(perf_template)
141 , p_(prb)
142 , tag_(normalize_tag(p_->tag, p_->ndims)) {}
143
144 void dump_desc(std::ostream &s) const override {
145 s << static_cast<const desc_t &>(*p_);
146 }
147
148 void dump_desc_csv(std::ostream &s) const override {
149 s << p_->mb << ',' << p_->ic << ',' << p_->id << ',' << p_->ih << ','
150 << p_->iw << ',' << p_->eps;
151 }
152
153 void dump_flags(std::ostream &s) const override {
154 s << flags2str(p_->flags);
155 }
156
157 const attr_t *attr() const override { return &p_->attr; }
158 const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
159 const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
160 const int64_t *user_mb() const override { return &p_->user_mb; }
161 const std::string *name() const override { return &p_->name; }
162 const dir_t *dir() const override { return &p_->dir; }
163 const dnnl_data_type_t *dt() const override { return &p_->dt; }
164 const std::string *tag() const override { return &tag_; }
165
166private:
167 const prb_t *p_;
168 std::string tag_;
169};
170
/* Extra control parameters and helpers that should not be placed in prb_t. */
172
173inline size_t data_off(const prb_t *prb, int64_t mb, int64_t c, int64_t d,
174 int64_t h, int64_t w) {
175 return (((mb * prb->ic + c) * prb->id + d) * prb->ih + h) * prb->iw + w;
176}
177
178void skip_unimplemented_prb(const prb_t *prb, res_t *res);
179void skip_invalid_prb(const prb_t *prb, res_t *res);
180void compute_ref(const prb_t *prb, const args_t &args,
181 dnnl_primitive_t prim_ref = nullptr);
182
183int doit(const prb_t *prb, res_t *res);
184int bench(int argc, char **argv);
185
186} // namespace bnorm
187
188#endif
189