1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef IP_HPP |
18 | #define IP_HPP |
19 | |
20 | #include <iostream> |
21 | |
22 | #include "oneapi/dnnl/dnnl.h" |
23 | |
24 | #include "common.hpp" |
25 | #include "dnnl_common.hpp" |
26 | #include "utils/perf_report.hpp" |
27 | #include "utils/settings.hpp" |
28 | |
29 | namespace ip { |
30 | |
31 | struct desc_t { |
32 | int64_t mb, oc, ic, id, ih, iw; |
33 | std::string name; |
34 | int ndims; |
35 | |
36 | dims_t src_dims() const; |
37 | dims_t wei_dims() const; |
38 | dims_t bia_dims() const; |
39 | dims_t dst_dims() const; |
40 | int64_t desc_nelems(int arg, int mask) const; |
41 | }; |
42 | int str2desc(desc_t *desc, const char *str); |
43 | std::ostream &operator<<(std::ostream &s, const desc_t &d); |
44 | |
45 | typedef struct dt_conf_t { |
46 | dnnl_data_type_t dt; |
47 | double min, max; /* representative */ |
48 | double f_min, f_max; /* fill range */ |
49 | int f_base; /* fill base, use 0 */ |
50 | double f_sparsity; /* amount of non-zeros, default 0.25 */ |
51 | double f_scale; /* fill scale, scaling factor for integer generated data */ |
52 | double eps; /* acceptable error */ |
53 | } _dt_conf_t[DAT_TOTAL]; |
54 | |
55 | extern const _dt_conf_t conf_f32; |
56 | extern const _dt_conf_t conf_bf16bf16f32; |
57 | |
58 | struct settings_t : public base_settings_t { |
59 | settings_t() = default; |
60 | |
61 | // ctor to save certain fields from resetting |
62 | settings_t(const char *perf_template) : settings_t() { |
63 | this->perf_template = perf_template; |
64 | } |
65 | |
66 | desc_t desc {}; |
67 | |
68 | std::vector<dir_t> dir {FWD_B}; |
69 | std::vector<const dt_conf_t *> cfg {conf_f32}; |
70 | std::vector<std::string> stag {tag::any}, wtag {tag::any}, dtag {tag::any}; |
71 | |
72 | const char *perf_template_csv() const { |
73 | static const std::string args = "%dir%,%cfg%,%stag%,%wtag%,%dtag%" ; |
74 | return perf_template_csv_base(args); |
75 | } |
76 | |
77 | void reset() { *this = settings_t(perf_template); } |
78 | }; |
79 | |
80 | struct prb_t : public desc_t { |
81 | prb_t(const desc_t &desc, int64_t mb, dir_t dir, const dt_conf_t *cfg, |
82 | const std::string &stag, const std::string &wtag, |
83 | const std::string &dtag, const attr_t &attr, |
84 | const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe) |
85 | : desc_t(desc) |
86 | , dir(dir) |
87 | , cfg(cfg) |
88 | , stag(stag) |
89 | , wtag(wtag) |
90 | , dtag(dtag) |
91 | , attr(attr) |
92 | , ctx_init(ctx_init) |
93 | , ctx_exe(ctx_exe) |
94 | , user_mb(mb) |
95 | , ops(0) |
96 | , src_scales(NULL) |
97 | , wei_scales(NULL) |
98 | , dst_scales(NULL) { |
99 | if (mb) this->mb = mb; |
100 | count_ops(); |
101 | src_scales = generate_scales(DNNL_ARG_SRC); |
102 | wei_scales = generate_scales(DNNL_ARG_WEIGHTS); |
103 | dst_scales = generate_scales(DNNL_ARG_DST); |
104 | } |
105 | ~prb_t() { |
106 | if (src_scales) zfree(src_scales); |
107 | if (wei_scales) zfree(wei_scales); |
108 | if (dst_scales) zfree(dst_scales); |
109 | } |
110 | |
111 | dir_t dir; |
112 | const dt_conf_t *cfg; |
113 | std::string stag, wtag, dtag; |
114 | attr_t attr; |
115 | thr_ctx_t ctx_init, ctx_exe; |
116 | int64_t user_mb; |
117 | |
118 | double ops; |
119 | float *src_scales, *wei_scales, *dst_scales; |
120 | |
121 | void count_ops() { |
122 | if (ops > 0) return; |
123 | ops = 2. * mb * ic * oc * id * ih * iw; |
124 | }; |
125 | |
126 | dt_conf_t get_dt_conf(data_kind_t dk) const { |
127 | return (attr.fpmath_mode == dnnl_fpmath_mode_bf16 && cfg == conf_f32) |
128 | ? conf_bf16bf16f32[dk] |
129 | : cfg[dk]; |
130 | } |
131 | |
132 | float *generate_scales(int arg) const; |
133 | |
134 | BENCHDNN_DISALLOW_COPY_AND_ASSIGN(prb_t); |
135 | }; |
136 | std::ostream &operator<<(std::ostream &s, const prb_t &prb); |
137 | |
138 | const dt_conf_t *str2cfg(const char *str); |
139 | std::ostream &operator<<(std::ostream &s, const dt_conf_t *cfg); |
140 | |
141 | struct perf_report_t : public base_perf_report_t { |
142 | perf_report_t(const prb_t *prb, const char *perf_template) |
143 | : base_perf_report_t(perf_template) |
144 | , p_(prb) |
145 | , stag_({normalize_tag(p_->stag, p_->ndims)}) |
146 | , wtag_(normalize_tag(p_->wtag, p_->ndims)) |
147 | , dtag_(normalize_tag(p_->dtag, p_->ndims)) {} |
148 | |
149 | void dump_cfg(std::ostream &s) const override { s << p_->cfg; } |
150 | |
151 | void dump_desc(std::ostream &s) const override { |
152 | s << static_cast<const desc_t &>(*p_); |
153 | } |
154 | |
155 | void dump_desc_csv(std::ostream &s) const override { |
156 | s << p_->mb << ',' << p_->oc << ',' << p_->ic << ',' << p_->id << ',' |
157 | << p_->ih << ',' << p_->iw; |
158 | } |
159 | |
160 | double ops() const override { return p_->ops; } |
161 | const attr_t *attr() const override { return &p_->attr; } |
162 | const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; } |
163 | const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; } |
164 | const int64_t *user_mb() const override { return &p_->user_mb; } |
165 | const std::string *name() const override { return &p_->name; } |
166 | const dir_t *dir() const override { return &p_->dir; } |
167 | const std::vector<std::string> *stag() const override { return &stag_; } |
168 | const std::string *wtag() const override { return &wtag_; } |
169 | const std::string *dtag() const override { return &dtag_; } |
170 | |
171 | private: |
172 | const prb_t *p_; |
173 | std::vector<std::string> stag_; |
174 | std::string wtag_, dtag_; |
175 | }; |
176 | |
177 | inline size_t src_off_f(const prb_t *prb, int64_t mb, int64_t ic, int64_t id, |
178 | int64_t ih, int64_t iw) { |
179 | return (((mb * prb->ic + ic) * prb->id + id) * prb->ih + ih) * prb->iw + iw; |
180 | } |
181 | |
182 | inline size_t wei_off_f(const prb_t *prb, int64_t oc, int64_t ic, int64_t id, |
183 | int64_t ih, int64_t iw) { |
184 | return (((oc * prb->ic + ic) * prb->id + id) * prb->ih + ih) * prb->iw + iw; |
185 | } |
186 | |
187 | inline size_t bia_off_f(const prb_t *prb, int64_t oc) { |
188 | return oc; |
189 | } |
190 | |
191 | inline size_t dst_off_f(const prb_t *prb, int64_t mb, int64_t oc) { |
192 | return mb * prb->oc + oc; |
193 | } |
194 | |
195 | void skip_unimplemented_prb(const prb_t *prb, res_t *res); |
196 | void skip_invalid_prb(const prb_t *prb, res_t *res); |
197 | void compute_ref(const prb_t *prb, const args_t &args, |
198 | dnnl_primitive_t prim_ref = nullptr); |
199 | |
200 | int doit(const prb_t *prb, res_t *res); |
201 | |
202 | int bench(int argc, char **argv); |
203 | } // namespace ip |
204 | |
205 | #endif |
206 | |