1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef BRGEMM_HPP
18#define BRGEMM_HPP
19
20#include <algorithm>
21#include <bitset>
22#include <iostream>
23#include <map>
24#include <numeric>
25
26#include "oneapi/dnnl/dnnl.h"
27
28#if defined(DNNL_X64) && DNNL_X64 == 1 \
29 && (DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE)
30#include "src/cpu/x64/brgemm/brgemm.hpp"
31#endif
32
33#include "common.hpp"
34#include "dnnl_common.hpp"
35#include "utils/cfg.hpp"
36#include "utils/perf_report.hpp"
37#include "utils/settings.hpp"
38
39namespace brgemm {
40
41struct settings_t : public base_settings_t {
42 settings_t() = default;
43
44 // ctor to save certain fields from resetting
45 settings_t(const char *perf_template) : settings_t() {
46 this->perf_template = perf_template;
47 }
48
49 prb_vdims_t prb_vdims;
50
51 std::vector<std::vector<dnnl_data_type_t>> dt {{dnnl_f32}};
52 std::vector<std::string> stag {tag::abx}, wtag {tag::undef},
53 dtag {tag::abx};
54 std::vector<std::vector<int64_t>> ld {{}};
55 std::vector<dnnl_data_type_t> bia_dt {dnnl_data_type_undef};
56 std::vector<int> batch_size {1};
57 std::vector<float> alpha {1.f}, beta {0.f};
58 std::vector<std::string> brgemm_attr {std::string()};
59
60 const char *perf_template_csv() const {
61 static const std::string args = "";
62 return perf_template_csv_base(args);
63 }
64
65 void reset() { *this = settings_t(perf_template); }
66};
67
68struct prb_t : public prb_vdims_t {
69 prb_t(const prb_vdims_t &prb_vdims, const std::vector<dnnl_data_type_t> &dt,
70 const std::string &stag, const std::string &wtag,
71 const std::string &dtag, const std::vector<int64_t> &ld,
72 dnnl_data_type_t bia_dt, float alpha, float beta, int batch_size,
73 const std::string &brgemm_attr, const attr_t &attr,
74 const thr_ctx_t &ctx_init, const thr_ctx_t &ctx_exe)
75 : prb_vdims_t(prb_vdims)
76 , dt(dt)
77 , stag(stag)
78 , wtag(wtag)
79 , dtag(dtag)
80 , ld(ld)
81 , bia_dt(bia_dt)
82 , alpha(alpha)
83 , beta(beta)
84 , batch_size(batch_size)
85 , brgemm_attr(brgemm_attr)
86 , attr(attr)
87 , ctx_init(ctx_init)
88 , ctx_exe(ctx_exe)
89 , scales(NULL) {
90
91 // Broadcast data types if needed
92 if (dt.size() == 1) {
93 const auto val = dt[0]; // Need a copy here.
94 this->dt.assign(3, val);
95 }
96
97 const auto &srcdims = src_dims();
98 const auto &weidims = weights_dims();
99 m = srcdims[ndims - 2];
100 k = srcdims.back();
101 n = weidims.back();
102 dst_dims[ndims - 2] = m;
103 dst_dims[ndims - 1] = n;
104
105 const auto nelems = std::accumulate(dst_dims.begin(), dst_dims.end(),
106 (dnnl_dim_t)1, std::multiplies<dnnl_dim_t>());
107 ops = 2. * nelems * k;
108
109 generate_oscales();
110 src_zp = generate_zero_points(DNNL_ARG_SRC, attr.zero_points, k);
111 dst_zp = generate_zero_points(DNNL_ARG_DST, attr.zero_points, n);
112 }
113 ~prb_t() {
114 if (scales) zfree(scales);
115 if (src_zp) zfree(src_zp);
116 if (dst_zp) zfree(dst_zp);
117 }
118
119 int m, n, k;
120 dir_t dir = FLAG_FWD; // Lack of prop_kind, always considered as forward.
121 std::vector<dnnl_data_type_t> dt;
122 std::string stag, wtag, dtag;
123 std::vector<int64_t> ld;
124 dnnl_data_type_t bia_dt;
125 float alpha, beta;
126 int batch_size;
127 std::string brgemm_attr;
128
129 attr_t attr;
130 thr_ctx_t ctx_init, ctx_exe;
131
132 double ops;
133 float *scales;
134 int32_t *src_zp, *dst_zp;
135
136 const dims_t &src_dims() const { return vdims[0]; }
137 const dims_t &weights_dims() const { return vdims[1]; }
138 // const dims_t &prb_vdims_t::dst_dims() const;
139
140 dnnl_data_type_t src_dt() const { return dt[0]; }
141 dnnl_data_type_t wei_dt() const { return dt[1]; }
142 dnnl_data_type_t acc_dt() const {
143 return is_integral_dt(wei_dt()) ? dnnl_s32 : dnnl_f32;
144 }
145 dnnl_data_type_t dst_dt() const { return dt[2]; }
146 dnnl_data_type_t get_dt(data_kind_t data_kind) const;
147
148 int64_t get_lda() const {
149 if (!ld.empty() && ld[0] != 0) {
150 assert(ld[0] >= batch_size * k);
151 return ld[0];
152 }
153 return batch_size * k;
154 }
155 int64_t get_ldb() const {
156 if (!ld.empty() && ld[1] != 0) {
157 assert(ld[1] >= n);
158 return ld[1];
159 }
160 return n;
161 }
162 int64_t get_ldc(bool use_dst_as_acc) const {
163 if (use_dst_as_acc) return get_ldd();
164 return n;
165 }
166 int64_t get_ldd() const {
167 if (!ld.empty() && ld[2] != 0) {
168 assert(ld[2] >= n);
169 return ld[2];
170 }
171 return n;
172 }
173
174 void generate_oscales();
175 int32_t *generate_zero_points(
176 int arg, const attr_t::zero_points_t &zero_points, int N);
177
178 BENCHDNN_DISALLOW_COPY_AND_ASSIGN(prb_t);
179};
180std::ostream &operator<<(std::ostream &s, const prb_t &prb);
181
182// TODO: not supported as of now.
183struct perf_report_t : public base_perf_report_t {
184 perf_report_t(const prb_t *prb, const char *perf_template)
185 : base_perf_report_t(perf_template)
186 , p_(prb)
187 , stag_({normalize_tag(p_->stag, p_->ndims)})
188 , wtag_(normalize_tag(p_->wtag, p_->ndims))
189 , dtag_(normalize_tag(p_->dtag, p_->ndims)) {}
190
191 void dump_desc(std::ostream &s) const override {
192 s << static_cast<const prb_vdims_t &>(*p_);
193 }
194
195 void dump_desc_csv(std::ostream &s) const override { dump_desc(s); }
196
197 double ops() const override { return p_->ops; }
198 const std::vector<dnnl_data_type_t> *sdt() const override {
199 return &p_->dt;
200 }
201 const attr_t *attr() const override { return &p_->attr; }
202 const thr_ctx_t *ctx_init() const override { return &p_->ctx_init; }
203 const thr_ctx_t *ctx_exe() const override { return &p_->ctx_exe; }
204 const std::string *name() const override { return &p_->name; }
205 const std::vector<std::string> *stag() const override { return &stag_; }
206 const std::string *wtag() const override { return &wtag_; }
207 const std::string *dtag() const override { return &dtag_; }
208
209private:
210 const prb_t *p_;
211 std::vector<std::string> stag_;
212 std::string wtag_, dtag_;
213};
214
215struct cfg_t : public base_cfg_t {
216 cfg_t(const prb_t *prb, std::vector<data_kind_t> kinds) {
217 for (const auto kind : kinds) {
218 auto orig_data_type = prb->get_dt(kind);
219 auto data_type
220 = deduce_cfg_data_type(orig_data_type, prb->attr, kind);
221 cfg_entry_.push_back(cfg_entry_t(
222 kind, orig_data_type, data_type, get_cfg_map(kind)));
223 }
224 }
225
226 const cfg_entry_t::cfg_map_t &get_cfg_map(data_kind_t kind) const;
227
228 float get_density(const density_args_t &density_args) const override;
229};
230
231void skip_unimplemented_prb(const prb_t *prb, res_t *res);
232void skip_invalid_prb(const prb_t *prb, res_t *res);
233void compute_ref(const prb_t *prb, const args_t &args,
234 dnnl_primitive_t prim_ref = nullptr);
235
236int doit(const prb_t *prb, res_t *res);
237
238int bench(int argc, char **argv);
239
240} // namespace brgemm
241
242#endif
243