1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef IP_HPP
18#define IP_HPP
19
20#include <iostream>
21
22#include "oneapi/dnnl/dnnl.h"
23
24#include "common.hpp"
25#include "dnnl_common.hpp"
26#include "utils/perf_report.hpp"
27#include "utils/settings.hpp"
28
29namespace ip {
30
31struct desc_t {
32 int64_t mb, oc, ic, id, ih, iw;
33 std::string name;
34 int ndims;
35
36 dims_t src_dims() const;
37 dims_t wei_dims() const;
38 dims_t bia_dims() const;
39 dims_t dst_dims() const;
40 int64_t desc_nelems(int arg, int mask) const;
41};
42int str2desc(desc_t *desc, const char *str);
43std::ostream &operator<<(std::ostream &s, const desc_t &d);
44
45typedef struct dt_conf_t {
46 dnnl_data_type_t dt;
47 double min, max; /* representative */
48 double f_min, f_max; /* fill range */
49 int f_base; /* fill base, use 0 */
50 double f_sparsity; /* amount of non-zeros, default 0.25 */
51 double f_scale; /* fill scale, scaling factor for integer generated data */
52 double eps; /* acceptable error */
53} _dt_conf_t[DAT_TOTAL];
54
55extern const _dt_conf_t conf_f32;
56extern const _dt_conf_t conf_bf16bf16f32;
57
58struct settings_t : public base_settings_t {
59 settings_t() = default;
60
61 // ctor to save certain fields from resetting
62 settings_t(const char *perf_template) : settings_t() {
63 this->perf_template = perf_template;
64 }
65
66 desc_t desc {};
67
68 std::vector<dir_t> dir {FWD_B};
69 std::vector<const dt_conf_t *> cfg {conf_f32};
70 std::vector<std::string> stag {tag::any}, wtag {tag::any}, dtag {tag::any};
71
72 const char *perf_template_csv() const {
73 static const std::string args = "%dir%,%cfg%,%stag%,%wtag%,%dtag%";
74 return perf_template_csv_base(args);
75 }
76
77 void reset() { *this = settings_t(perf_template); }
78};
79
80struct prb_t : public desc_t {
81 prb_t(const desc_t &desc, int64_t mb, dir_t dir, const dt_conf_t *cfg,
82 const std::string &stag, const std::string &wtag,
83 const std::string &dtag, const attr_t &attr,
84 const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe)
85 : desc_t(desc)
86 , dir(dir)
87 , cfg(cfg)
88 , stag(stag)
89 , wtag(wtag)
90 , dtag(dtag)
91 , attr(attr)
92 , ctx_init(ctx_init)
93 , ctx_exe(ctx_exe)
94 , user_mb(mb)
95 , ops(0)
96 , src_scales(NULL)
97 , wei_scales(NULL)
98 , dst_scales(NULL) {
99 if (mb) this->mb = mb;
100 count_ops();
101 src_scales = generate_scales(DNNL_ARG_SRC);
102 wei_scales = generate_scales(DNNL_ARG_WEIGHTS);
103 dst_scales = generate_scales(DNNL_ARG_DST);
104 }
105 ~prb_t() {
106 if (src_scales) zfree(src_scales);
107 if (wei_scales) zfree(wei_scales);
108 if (dst_scales) zfree(dst_scales);
109 }
110
111 dir_t dir;
112 const dt_conf_t *cfg;
113 std::string stag, wtag, dtag;
114 attr_t attr;
115 thr_ctx_t ctx_init, ctx_exe;
116 int64_t user_mb;
117
118 double ops;
119 float *src_scales, *wei_scales, *dst_scales;
120
121 void count_ops() {
122 if (ops > 0) return;
123 ops = 2. * mb * ic * oc * id * ih * iw;
124 };
125
126 dt_conf_t get_dt_conf(data_kind_t dk) const {
127 return (attr.fpmath_mode == dnnl_fpmath_mode_bf16 && cfg == conf_f32)
128 ? conf_bf16bf16f32[dk]
129 : cfg[dk];
130 }
131
132 float *generate_scales(int arg) const;
133
134 BENCHDNN_DISALLOW_COPY_AND_ASSIGN(prb_t);
135};
136std::ostream &operator<<(std::ostream &s, const prb_t &prb);
137
138const dt_conf_t *str2cfg(const char *str);
139std::ostream &operator<<(std::ostream &s, const dt_conf_t *cfg);
140
141struct perf_report_t : public base_perf_report_t {
142 perf_report_t(const prb_t *prb, const char *perf_template)
143 : base_perf_report_t(perf_template)
144 , p_(prb)
145 , stag_({normalize_tag(p_->stag, p_->ndims)})
146 , wtag_(normalize_tag(p_->wtag, p_->ndims))
147 , dtag_(normalize_tag(p_->dtag, p_->ndims)) {}
148
149 void dump_cfg(std::ostream &s) const override { s << p_->cfg; }
150
151 void dump_desc(std::ostream &s) const override {
152 s << static_cast<const desc_t &>(*p_);
153 }
154
155 void dump_desc_csv(std::ostream &s) const override {
156 s << p_->mb << ',' << p_->oc << ',' << p_->ic << ',' << p_->id << ','
157 << p_->ih << ',' << p_->iw;
158 }
159
160 double ops() const override { return p_->ops; }
161 const attr_t *attr() const override { return &p_->attr; }
162 const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
163 const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
164 const int64_t *user_mb() const override { return &p_->user_mb; }
165 const std::string *name() const override { return &p_->name; }
166 const dir_t *dir() const override { return &p_->dir; }
167 const std::vector<std::string> *stag() const override { return &stag_; }
168 const std::string *wtag() const override { return &wtag_; }
169 const std::string *dtag() const override { return &dtag_; }
170
171private:
172 const prb_t *p_;
173 std::vector<std::string> stag_;
174 std::string wtag_, dtag_;
175};
176
177inline size_t src_off_f(const prb_t *prb, int64_t mb, int64_t ic, int64_t id,
178 int64_t ih, int64_t iw) {
179 return (((mb * prb->ic + ic) * prb->id + id) * prb->ih + ih) * prb->iw + iw;
180}
181
182inline size_t wei_off_f(const prb_t *prb, int64_t oc, int64_t ic, int64_t id,
183 int64_t ih, int64_t iw) {
184 return (((oc * prb->ic + ic) * prb->id + id) * prb->ih + ih) * prb->iw + iw;
185}
186
187inline size_t bia_off_f(const prb_t *prb, int64_t oc) {
188 return oc;
189}
190
191inline size_t dst_off_f(const prb_t *prb, int64_t mb, int64_t oc) {
192 return mb * prb->oc + oc;
193}
194
195void skip_unimplemented_prb(const prb_t *prb, res_t *res);
196void skip_invalid_prb(const prb_t *prb, res_t *res);
197void compute_ref(const prb_t *prb, const args_t &args,
198 dnnl_primitive_t prim_ref = nullptr);
199
200int doit(const prb_t *prb, res_t *res);
201
202int bench(int argc, char **argv);
203} // namespace ip
204
205#endif
206