1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <stdio.h>
18#include <stdlib.h>
19
20#include <sstream>
21
22#include "dnnl_common.hpp"
23#include "utils/parser.hpp"
24
25#include "matmul/matmul.hpp"
26
27namespace matmul {
28
29void check_correctness(const settings_t &s, const settings_t &def) {
30 std::vector<std::pair<dnnl_data_type_t, int>> bia_cfg;
31 for (const auto &i_bia_dt : s.bia_dt) {
32 if (i_bia_dt == dnnl_data_type_undef) {
33 bia_cfg.emplace_back(i_bia_dt, 0);
34 continue;
35 }
36 for (const auto &i_bia_mask : s.bia_mask)
37 bia_cfg.emplace_back(i_bia_dt, i_bia_mask);
38 }
39
40 for_(const auto &i_dt_ : s.dt)
41 for_(const auto &i_cfg : s.cfg)
42 for_(const auto &i_stag : s.stag)
43 for_(const auto &i_wtag : s.wtag)
44 for_(const auto &i_dtag : s.dtag)
45 for_(const auto &i_strides : s.strides)
46 for_(const auto &i_rt_dims_masks : s.rt_dims_masks)
47 for_(const auto &i_scales : s.scales)
48 for_(const auto &i_zero_points : s.zero_points)
49 for_(const auto &i_post_ops : s.post_ops)
50 for_(const auto &i_scratchpad_mode : s.scratchpad_mode)
51 for_(const auto &i_ctx_init : s.ctx_init)
52 for_(const auto &i_ctx_exe : s.ctx_exe)
53 for_(const auto &i_fpmath_mode : s.fpmath_mode)
54 for (const auto &i_bia_cfg : bia_cfg) {
55 auto attr = settings_t::get_attr(i_scales, i_zero_points, i_post_ops,
56 i_scratchpad_mode, i_fpmath_mode);
57
58 const bool strided_input = !i_strides[STRIDES_SRC].empty()
59 || !i_strides[STRIDES_WEI].empty()
60 || !i_strides[STRIDES_DST].empty();
61 if (strided_input) {
62 const bool no_stride_with_tag
63 = IMPLICATION(i_stag != def.stag[0],
64 i_strides[STRIDES_SRC].empty())
65 && IMPLICATION(i_wtag != def.wtag[0],
66 i_strides[STRIDES_WEI].empty())
67 && IMPLICATION(i_dtag != def.dtag[0],
68 i_strides[STRIDES_DST].empty());
69
70 if (!no_stride_with_tag) {
71 fprintf(stderr,
72 "ERROR: matmul driver: both `strides` and `tag` knobs "
73 "can not be used with either of `src`, `wei`, and `dst`"
74 " tensors.\n"),
75 fflush(stderr);
76 SAFE_V(FAIL);
77 }
78 }
79
80 auto i_dt = i_dt_;
81 if (!i_cfg.empty()) {
82 if (i_dt.size() == 1 && i_dt[0] == dnnl_f32) {
83 handle_legacy_cfg(i_dt, i_cfg);
84 } else {
85 fprintf(stderr,
86 "ERROR: matmul driver: `dt` and `cfg` knobs are "
87 "incompatible with each other. Please specify only one "
88 "of them at a time.\n"),
89 fflush(stderr);
90 SAFE_V(FAIL);
91 }
92 }
93
94 static constexpr int n_inputs = 3;
95 if (i_dt.size() != 1 && i_dt.size() != n_inputs) {
96 fprintf(stderr,
97 "ERROR: matmul driver: `dt` option expects either a single "
98 "input or three inputs in SRC, WEI, DST order. Current "
99 "size is: \"%ld\"\n",
100 (long)i_dt.size()),
101 fflush(stderr);
102 SAFE_V(FAIL);
103 }
104
105 const prb_t prb(s.prb_vdims, i_dt, i_stag, i_wtag, i_dtag, i_strides,
106 i_bia_cfg.first, i_bia_cfg.second, i_rt_dims_masks, attr,
107 i_ctx_init, i_ctx_exe);
108 std::stringstream ss;
109 ss << prb;
110 const std::string cpp_pstr = ss.str();
111 const char *pstr = cpp_pstr.c_str();
112 BENCHDNN_PRINT(1, "run: %s\n", pstr);
113
114 res_t res {};
115 doit(&prb, &res);
116
117 parse_result(res, pstr);
118
119 if (is_bench_mode(PERF)) {
120 perf_report_t pr(&prb, s.perf_template);
121 pr.report(&res, pstr);
122 }
123 }
124}
125
// Help text registered with the `bia_mask` option parser below.
static const std::string help_bia_mask {
        "UINT (Default: `2`)\n"
        " Specifies a bit-mask that indicates which bias dimensions coincide "
        "with C matrix dimensions, when `1` is on a correspondent dimension.\n"};
130
// Help text registered with the `runtime_dims_masks` option parser below.
// The value is two colon-separated masks, as the text itself describes.
static const std::string help_runtime_dims_masks {
        "UINT:UINT (Default: `0:0`)\n"
        " Specifies a bit-mask for matrices A and B that indicates whether a "
        "dimension is `DNNL_RUNTIME_DIM_VAL` if `1` on a correspondent "
        "dimension.\n"};
135
136int bench(int argc, char **argv) {
137 driver_name = "matmul";
138 using namespace parser;
139 static settings_t s;
140 static const settings_t def {};
141 for (; argc > 0; --argc, ++argv) {
142 const bool parsed_options = parse_bench_settings(argv[0])
143 || parse_batch(bench, argv[0])
144 || parse_multi_dt(s.dt, def.dt, argv[0], "dt")
145 || parse_cfg(s.cfg, def.cfg, str2cfg, argv[0])
146 || parse_tag(s.stag, def.stag, argv[0], "stag")
147 || parse_tag(s.wtag, def.wtag, argv[0], "wtag")
148 || parse_tag(s.dtag, def.dtag, argv[0], "dtag")
149 || parse_strides(s.strides, def.strides, argv[0], "strides")
150 || parse_dt(s.bia_dt, def.bia_dt, argv[0], "bia_dt")
151 || parse_vector_option(s.bia_mask, def.bia_mask, atoi, argv[0],
152 "bia_mask", help_bia_mask)
153 || parse_multivector_option(s.rt_dims_masks, def.rt_dims_masks,
154 atoi, argv[0], "runtime_dims_masks",
155 help_runtime_dims_masks)
156 || parse_attr_scales(s.scales, argv[0])
157 || parse_attr_zero_points(s.zero_points, argv[0])
158 || parse_attr_post_ops(s.post_ops, argv[0])
159 || parse_attr_scratchpad_mode(
160 s.scratchpad_mode, def.scratchpad_mode, argv[0])
161 || parse_attr_fpmath_mode(
162 s.fpmath_mode, def.fpmath_mode, argv[0])
163 || parse_ctx_init(s.ctx_init, def.ctx_init, argv[0])
164 || parse_ctx_exe(s.ctx_exe, def.ctx_exe, argv[0])
165 || parse_perf_template(s.perf_template, s.perf_template_def,
166 s.perf_template_csv(), argv[0])
167 || parse_reset(s, argv[0]) || parse_help(argv[0]);
168 if (!parsed_options) {
169 catch_unknown_options(argv[0]);
170
171 parse_prb_vdims(s.prb_vdims, argv[0]);
172 check_correctness(s, def);
173 }
174 }
175
176 return parse_last_argument();
177}
178
179} // namespace matmul
180