/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef MATMUL_HPP
#define MATMUL_HPP

#include <algorithm>
#include <bitset>
#include <iostream>
#include <map>
#include <numeric>

#include "oneapi/dnnl/dnnl.h"

#include "common.hpp"
#include "dnnl_common.hpp"
#include "utils/cfg.hpp"
#include "utils/perf_report.hpp"
#include "utils/settings.hpp"

namespace matmul {

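// Per-tensor bitmask over dimensions: bit i is set when the i-th dimension is
// treated as a runtime dimension rather than being fixed at creation time.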
typedef std::bitset<DNNL_MAX_NDIMS> dims_mask_t;

const int64_t LD_GOOD = INT64_MAX;
const int64_t LD_NONE = INT64_MAX - 1;

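// Command-line settings for the matmul driver. Each vector member collects
// every value requested for the corresponding option; the driver iterates
// over their combinations to generate individual problems.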
struct settings_t : public base_settings_t {
    settings_t() = default;

    // Ctor that preserves the perf report template across reset().
    settings_t(const char *perf_template) : settings_t() {
        this->perf_template = perf_template;
    }

    prb_vdims_t prb_vdims;

    std::vector<std::string> cfg {std::string()};
    std::vector<std::vector<dnnl_data_type_t>> dt {{dnnl_f32}};
    std::vector<std::string> stag {tag::any}, wtag {tag::any}, dtag {tag::any};
    std::vector<vdims_t> strides {vdims_t(STRIDES_SIZE)};
    std::vector<dnnl_data_type_t> bia_dt {dnnl_data_type_undef};
    std::vector<int> bia_mask {2};
    std::vector<std::vector<dims_mask_t>> rt_dims_masks {{}};

    const char *perf_template_csv() const {
        static const std::string args = "%cfg%,%stag%,%wtag%,%dtag%";
        return perf_template_csv_base(args);
    }

    void reset() { *this = settings_t(perf_template); }
};

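// Descriptor of a single matmul problem: dimensions, data types, memory format
// tags, strides, bias settings, runtime-dimension masks, attributes, and the
// derived M/N/K/batch sizes plus quantization parameters.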
struct prb_t : public prb_vdims_t {
    prb_t(const prb_vdims_t &prb_vdims, const std::vector<dnnl_data_type_t> &dt,
            const std::string &stag, const std::string &wtag,
            const std::string &dtag, const vdims_t &strides,
            dnnl_data_type_t bia_dt, int bia_mask,
            const std::vector<dims_mask_t> &rt_dims_masks, const attr_t &attr,
            const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe)
        : prb_vdims_t(prb_vdims)
        , dt(dt)
        , stag(stag)
        , wtag(wtag)
        , dtag(dtag)
        , strides(strides)
        , bia_dt(bia_dt)
        , bia_mask(bia_mask)
        , rt_dims_masks(rt_dims_masks)
        , attr(attr)
        , ctx_init(ctx_init)
        , ctx_exe(ctx_exe)
        , src_scales(NULL)
        , wei_scales(NULL)
        , dst_scales(NULL) {

        // Broadcast data types if needed
        if (dt.size() == 1) {
            const auto val = dt[0]; // Need a copy here.
            this->dt.assign(3, val);
        }

        this->rt_dims_masks.resize(2);
        const auto &srcdims = src_dims();
        const auto &weidims = weights_dims();
        m = srcdims[ndims - 2];
        k = srcdims.back();
        n = weidims.back();
        dst_dims[ndims - 2] = m;
        dst_dims[ndims - 1] = n;

        init_dst_rt_dims_mask();
        mb = std::accumulate(dst_dims.begin(), dst_dims.end() - 2,
                (dnnl_dim_t)1, std::multiplies<dnnl_dim_t>());
        const auto nelems = std::accumulate(dst_dims.begin(), dst_dims.end(),
                (dnnl_dim_t)1, std::multiplies<dnnl_dim_t>());
        ops = 2. * nelems * k;

        src_scales = generate_scales(DNNL_ARG_SRC);
        wei_scales = generate_scales(DNNL_ARG_WEIGHTS);
        dst_scales = generate_scales(DNNL_ARG_DST);
        src_zp = generate_zero_points(DNNL_ARG_SRC, attr.zero_points, k);
        dst_zp = generate_zero_points(DNNL_ARG_DST, attr.zero_points, n);
    }
    ~prb_t() {
        if (src_scales) zfree(src_scales);
        if (wei_scales) zfree(wei_scales);
        if (dst_scales) zfree(dst_scales);
        if (src_zp) zfree(src_zp);
        if (dst_zp) zfree(dst_zp);
    }

    int64_t m, n, k, mb;
    dir_t dir = FLAG_FWD; // Matmul has no prop_kind; always treated as forward.
    std::vector<dnnl_data_type_t> dt;
    std::string stag, wtag, dtag;
    vdims_t strides;
    dnnl_data_type_t bia_dt;
    int bia_mask;
    std::vector<dims_mask_t> rt_dims_masks;

    attr_t attr;
    thr_ctx_t ctx_init, ctx_exe;

    double ops;
    float *src_scales, *wei_scales, *dst_scales;
    int32_t *src_zp, *dst_zp;

    const dims_t &src_dims() const { return vdims[0]; }
    const dims_t &weights_dims() const { return vdims[1]; }
    // dst_dims comes from the prb_vdims_t base class.

    const dims_mask_t &src_runtime_dim_mask() const { return rt_dims_masks[0]; }
    const dims_mask_t &weights_runtime_dim_mask() const {
        return rt_dims_masks[1];
    }
    const dims_mask_t &dst_runtime_dim_mask() const { return rt_dims_masks[2]; }

    int src_broadcast_mask() const {
        return prb_vdims_t::get_broadcast_mask(0);
    }
    int weights_broadcast_mask() const {
        return prb_vdims_t::get_broadcast_mask(1);
    }
    int bias_broadcast_mask() const { return bia_mask; }

    dnnl_data_type_t src_dt() const { return dt[0]; }
    dnnl_data_type_t wei_dt() const { return dt[1]; }
    dnnl_data_type_t dst_dt() const { return dt[2]; }
    dnnl_data_type_t get_dt(data_kind_t data_kind) const;

    float *generate_scales(int arg) const;
    int32_t *generate_zero_points(
            int arg, const attr_t::zero_points_t &zero_points, int N) const;

    BENCHDNN_DISALLOW_COPY_AND_ASSIGN(prb_t);

private:
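    // Derives the destination runtime-dimension mask when the user did not
    // provide one: a batch dimension is runtime if it is runtime in either
    // input, the M dimension follows src, and the N dimension follows weights.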
    void init_dst_rt_dims_mask() {
        if (rt_dims_masks.size() > 2) return;

        const auto &src_rt_dim_mask = src_runtime_dim_mask();
        const auto &wei_rt_dim_mask = weights_runtime_dim_mask();
        dims_mask_t dst_rt_dim_mask;

        for (int i = 0; i < ndims - 2; ++i) {
            dst_rt_dim_mask[i] = src_rt_dim_mask[i] || wei_rt_dim_mask[i];
        }

        // m, n mask
        dst_rt_dim_mask[ndims - 2] = src_rt_dim_mask[ndims - 2];
        dst_rt_dim_mask[ndims - 1] = wei_rt_dim_mask[ndims - 1];

        rt_dims_masks.push_back(dst_rt_dim_mask);
    }
};
std::ostream &operator<<(std::ostream &s, const prb_t &prb);

/* some extra control parameters which shouldn't be placed in prb_t */

std::string str2cfg(const char *str);

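// Formats benchmark results for a single problem according to the
// user-supplied performance report template.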
struct perf_report_t : public base_perf_report_t {
    perf_report_t(const prb_t *prb, const char *perf_template)
        : base_perf_report_t(perf_template)
        , p_(prb)
        , stag_({normalize_tag(p_->stag, p_->ndims)})
        , wtag_(normalize_tag(p_->wtag, p_->ndims))
        , dtag_(normalize_tag(p_->dtag, p_->ndims)) {}

    void dump_desc(std::ostream &s) const override {
        s << static_cast<const prb_vdims_t &>(*p_);
    }

    void dump_desc_csv(std::ostream &s) const override { dump_desc(s); }

    double ops() const override { return p_->ops; }
    const std::vector<dnnl_data_type_t> *sdt() const override {
        return &p_->dt;
    }
    const attr_t *attr() const override { return &p_->attr; }
    const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
    const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
    const std::string *name() const override { return &p_->name; }
    const std::vector<std::string> *stag() const override { return &stag_; }
    const std::string *wtag() const override { return &wtag_; }
    const std::string *dtag() const override { return &dtag_; }

private:
    const prb_t *p_;
    std::vector<std::string> stag_;
    std::string wtag_, dtag_;
};

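// Fill configuration for each requested data kind; the effective data type
// may differ from the original one when it is deduced from the problem
// attributes (see deduce_cfg_data_type).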
struct cfg_t : public base_cfg_t {
    cfg_t(const prb_t *prb, std::vector<data_kind_t> kinds) {
        for (const auto kind : kinds) {
            auto orig_data_type = prb->get_dt(kind);
            auto data_type
                    = deduce_cfg_data_type(orig_data_type, prb->attr, kind);
            cfg_entry_.push_back(cfg_entry_t(
                    kind, orig_data_type, data_type, get_cfg_map(kind)));
        }
    }

    const cfg_entry_t::cfg_map_t &get_cfg_map(data_kind_t kind) const;

    float get_density(const density_args_t &density_args) const override;
};

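// Dense offset helpers: tensors are addressed as row-major 3D arrays with
// shapes src = (MB, M, K), weights = (MB, K, N), and dst = (MB, M, N).
// For example, a naive reference accumulation could look like
//     dst[dst_off_f(prb, mb, m, n)]
//             += src[src_off_f(prb, mb, m, k)] * wei[wei_off_f(prb, mb, k, n)];
// iterated over k for every (mb, m, n) triple (batch broadcasting aside).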
inline int64_t src_off_f(const prb_t *prb, int64_t mb, int64_t m, int64_t k) {
    return (mb * prb->m + m) * prb->k + k;
}

inline int64_t wei_off_f(const prb_t *prb, int64_t mb, int64_t k, int64_t n) {
    return (mb * prb->k + k) * prb->n + n;
}

inline int64_t dst_off_f(const prb_t *prb, int64_t mb, int64_t m, int64_t n) {
    return (mb * prb->m + m) * prb->n + n;
}

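// Translates a legacy --cfg string into the per-tensor data type vector `dt`.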
void handle_legacy_cfg(
        std::vector<dnnl_data_type_t> &dt, const std::string &cfg);

void skip_unimplemented_prb(const prb_t *prb, res_t *res);
void skip_invalid_prb(const prb_t *prb, res_t *res);
void compute_ref(const prb_t *prb, const args_t &args,
        dnnl_primitive_t prim_ref = nullptr);

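// Driver entry points: doit() runs and validates a single problem; bench()
// parses the command-line arguments and dispatches all generated problems.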
int doit(const prb_t *prb, res_t *res);

int bench(int argc, char **argv);

} // namespace matmul

#endif