1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef BRGEMM_HPP |
18 | #define BRGEMM_HPP |
19 | |
20 | #include <algorithm> |
21 | #include <bitset> |
22 | #include <iostream> |
23 | #include <map> |
24 | #include <numeric> |
25 | |
26 | #include "oneapi/dnnl/dnnl.h" |
27 | |
28 | #if defined(DNNL_X64) && DNNL_X64 == 1 \ |
29 | && (DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE) |
30 | #include "src/cpu/x64/brgemm/brgemm.hpp" |
31 | #endif |
32 | |
33 | #include "common.hpp" |
34 | #include "dnnl_common.hpp" |
35 | #include "utils/cfg.hpp" |
36 | #include "utils/perf_report.hpp" |
37 | #include "utils/settings.hpp" |
38 | |
39 | namespace brgemm { |
40 | |
41 | struct settings_t : public base_settings_t { |
42 | settings_t() = default; |
43 | |
44 | // ctor to save certain fields from resetting |
45 | settings_t(const char *perf_template) : settings_t() { |
46 | this->perf_template = perf_template; |
47 | } |
48 | |
49 | prb_vdims_t prb_vdims; |
50 | |
51 | std::vector<std::vector<dnnl_data_type_t>> dt {{dnnl_f32}}; |
52 | std::vector<std::string> stag {tag::abx}, wtag {tag::undef}, |
53 | dtag {tag::abx}; |
54 | std::vector<std::vector<int64_t>> ld {{}}; |
55 | std::vector<dnnl_data_type_t> bia_dt {dnnl_data_type_undef}; |
56 | std::vector<int> batch_size {1}; |
57 | std::vector<float> alpha {1.f}, beta {0.f}; |
58 | std::vector<std::string> brgemm_attr {std::string()}; |
59 | |
60 | const char *perf_template_csv() const { |
61 | static const std::string args = "" ; |
62 | return perf_template_csv_base(args); |
63 | } |
64 | |
65 | void reset() { *this = settings_t(perf_template); } |
66 | }; |
67 | |
68 | struct prb_t : public prb_vdims_t { |
69 | prb_t(const prb_vdims_t &prb_vdims, const std::vector<dnnl_data_type_t> &dt, |
70 | const std::string &stag, const std::string &wtag, |
71 | const std::string &dtag, const std::vector<int64_t> &ld, |
72 | dnnl_data_type_t bia_dt, float alpha, float beta, int batch_size, |
73 | const std::string &brgemm_attr, const attr_t &attr, |
74 | const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe) |
75 | : prb_vdims_t(prb_vdims) |
76 | , dt(dt) |
77 | , stag(stag) |
78 | , wtag(wtag) |
79 | , dtag(dtag) |
80 | , ld(ld) |
81 | , bia_dt(bia_dt) |
82 | , alpha(alpha) |
83 | , beta(beta) |
84 | , batch_size(batch_size) |
85 | , brgemm_attr(brgemm_attr) |
86 | , attr(attr) |
87 | , ctx_init(ctx_init) |
88 | , ctx_exe(ctx_exe) |
89 | , scales(NULL) { |
90 | |
91 | // Broadcast data types if needed |
92 | if (dt.size() == 1) { |
93 | const auto val = dt[0]; // Need a copy here. |
94 | this->dt.assign(3, val); |
95 | } |
96 | |
97 | const auto &srcdims = src_dims(); |
98 | const auto &weidims = weights_dims(); |
99 | m = srcdims[ndims - 2]; |
100 | k = srcdims.back(); |
101 | n = weidims.back(); |
102 | dst_dims[ndims - 2] = m; |
103 | dst_dims[ndims - 1] = n; |
104 | |
105 | const auto nelems = std::accumulate(dst_dims.begin(), dst_dims.end(), |
106 | (dnnl_dim_t)1, std::multiplies<dnnl_dim_t>()); |
107 | ops = 2. * nelems * k; |
108 | |
109 | generate_oscales(); |
110 | src_zp = generate_zero_points(DNNL_ARG_SRC, attr.zero_points, k); |
111 | dst_zp = generate_zero_points(DNNL_ARG_DST, attr.zero_points, n); |
112 | } |
113 | ~prb_t() { |
114 | if (scales) zfree(scales); |
115 | if (src_zp) zfree(src_zp); |
116 | if (dst_zp) zfree(dst_zp); |
117 | } |
118 | |
119 | int m, n, k; |
120 | dir_t dir = FLAG_FWD; // Lack of prop_kind, always considered as forward. |
121 | std::vector<dnnl_data_type_t> dt; |
122 | std::string stag, wtag, dtag; |
123 | std::vector<int64_t> ld; |
124 | dnnl_data_type_t bia_dt; |
125 | float alpha, beta; |
126 | int batch_size; |
127 | std::string brgemm_attr; |
128 | |
129 | attr_t attr; |
130 | thr_ctx_t ctx_init, ctx_exe; |
131 | |
132 | double ops; |
133 | float *scales; |
134 | int32_t *src_zp, *dst_zp; |
135 | |
136 | const dims_t &src_dims() const { return vdims[0]; } |
137 | const dims_t &weights_dims() const { return vdims[1]; } |
138 | // const dims_t &prb_vdims_t::dst_dims() const; |
139 | |
140 | dnnl_data_type_t src_dt() const { return dt[0]; } |
141 | dnnl_data_type_t wei_dt() const { return dt[1]; } |
142 | dnnl_data_type_t acc_dt() const { |
143 | return is_integral_dt(wei_dt()) ? dnnl_s32 : dnnl_f32; |
144 | } |
145 | dnnl_data_type_t dst_dt() const { return dt[2]; } |
146 | dnnl_data_type_t get_dt(data_kind_t data_kind) const; |
147 | |
148 | int64_t get_lda() const { |
149 | if (!ld.empty() && ld[0] != 0) { |
150 | assert(ld[0] >= batch_size * k); |
151 | return ld[0]; |
152 | } |
153 | return batch_size * k; |
154 | } |
155 | int64_t get_ldb() const { |
156 | if (!ld.empty() && ld[1] != 0) { |
157 | assert(ld[1] >= n); |
158 | return ld[1]; |
159 | } |
160 | return n; |
161 | } |
162 | int64_t get_ldc(bool use_dst_as_acc) const { |
163 | if (use_dst_as_acc) return get_ldd(); |
164 | return n; |
165 | } |
166 | int64_t get_ldd() const { |
167 | if (!ld.empty() && ld[2] != 0) { |
168 | assert(ld[2] >= n); |
169 | return ld[2]; |
170 | } |
171 | return n; |
172 | } |
173 | |
174 | void generate_oscales(); |
175 | int32_t *generate_zero_points( |
176 | int arg, const attr_t::zero_points_t &zero_points, int N); |
177 | |
178 | BENCHDNN_DISALLOW_COPY_AND_ASSIGN(prb_t); |
179 | }; |
// Stream insertion operator for a problem descriptor. Only declared in this
// header; presumably it writes a textual form of `prb` to `s` — confirm
// against the definition.
std::ostream &operator<<(std::ostream &s, const prb_t &prb);
181 | |
182 | // TODO: not supported as of now. |
183 | struct perf_report_t : public base_perf_report_t { |
184 | perf_report_t(const prb_t *prb, const char *perf_template) |
185 | : base_perf_report_t(perf_template) |
186 | , p_(prb) |
187 | , stag_({normalize_tag(p_->stag, p_->ndims)}) |
188 | , wtag_(normalize_tag(p_->wtag, p_->ndims)) |
189 | , dtag_(normalize_tag(p_->dtag, p_->ndims)) {} |
190 | |
191 | void dump_desc(std::ostream &s) const override { |
192 | s << static_cast<const prb_vdims_t &>(*p_); |
193 | } |
194 | |
195 | void dump_desc_csv(std::ostream &s) const override { dump_desc(s); } |
196 | |
197 | double ops() const override { return p_->ops; } |
198 | const std::vector<dnnl_data_type_t> *sdt() const override { |
199 | return &p_->dt; |
200 | } |
201 | const attr_t *attr() const override { return &p_->attr; } |
202 | const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; } |
203 | const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; } |
204 | const std::string *name() const override { return &p_->name; } |
205 | const std::vector<std::string> *stag() const override { return &stag_; } |
206 | const std::string *wtag() const override { return &wtag_; } |
207 | const std::string *dtag() const override { return &dtag_; } |
208 | |
209 | private: |
210 | const prb_t *p_; |
211 | std::vector<std::string> stag_; |
212 | std::string wtag_, dtag_; |
213 | }; |
214 | |
215 | struct cfg_t : public base_cfg_t { |
216 | cfg_t(const prb_t *prb, std::vector<data_kind_t> kinds) { |
217 | for (const auto kind : kinds) { |
218 | auto orig_data_type = prb->get_dt(kind); |
219 | auto data_type |
220 | = deduce_cfg_data_type(orig_data_type, prb->attr, kind); |
221 | cfg_entry_.push_back(cfg_entry_t( |
222 | kind, orig_data_type, data_type, get_cfg_map(kind))); |
223 | } |
224 | } |
225 | |
226 | const cfg_entry_t::cfg_map_t &get_cfg_map(data_kind_t kind) const; |
227 | |
228 | float get_density(const density_args_t &density_args) const override; |
229 | }; |
230 | |
// Driver entry points. They are only declared in this header; the notes
// below are inferred from names and signatures and should be confirmed
// against the definitions.

// Presumably flags `res` when `prb` describes a case that is not implemented.
void skip_unimplemented_prb(const prb_t *prb, res_t *res);
// Presumably flags `res` when `prb` describes an invalid case.
void skip_invalid_prb(const prb_t *prb, res_t *res);
// Reference computation for `prb` over `args`; `prim_ref` is optional and
// defaults to nullptr.
void compute_ref(const prb_t *prb, const args_t &args,
        dnnl_primitive_t prim_ref = nullptr);

// Presumably runs a single problem and reports its outcome through `res`.
int doit(const prb_t *prb, res_t *res);

// Presumably the command-line entry point of the brgemm driver.
int bench(int argc, char **argv);
239 | |
240 | } // namespace brgemm |
241 | |
242 | #endif |
243 | |