1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <sstream>
18
19#include "utils/parser.hpp"
20
21#include "reduction/reduction.hpp"
22
23namespace reduction {
24
25void check_correctness(const settings_t &s) {
26 for_(const auto &i_sdt : s.sdt)
27 for_(const auto &i_ddt : s.ddt)
28 for_(const auto &i_stag : s.stag)
29 for_(const auto &i_dtag : s.dtag)
30 for_(const auto &i_post_ops : s.post_ops)
31 for_(const auto &i_alg : s.alg)
32 for_(const auto &i_p : s.p)
33 for_(const auto &i_eps : s.eps)
34 for_(const auto &i_scratchpad_mode : s.scratchpad_mode)
35 for_(const auto &i_ctx_init : s.ctx_init)
36 for (const auto &i_ctx_exe : s.ctx_exe) {
37 // Expect exactly two inputs for problem dimensions.
38 static constexpr int n_inputs = 2;
39 if (s.prb_vdims.n_inputs() != n_inputs) {
40 BENCHDNN_PRINT(0, "%s\n",
41 "Error: input tensors were specified in wrong format. "
42 "Please use NxNxNxNxN:MxMxMxMxM as a problem description "
43 "format.");
44 SAFE_V(FAIL);
45 }
46
47 auto attr = settings_t::get_attr(i_post_ops, i_scratchpad_mode);
48
49 const prb_t prb(s.prb_vdims, i_sdt, i_ddt, i_stag, i_dtag, i_alg, i_p,
50 i_eps, attr, i_ctx_init, i_ctx_exe);
51 std::stringstream ss;
52 ss << prb;
53 const std::string cpp_pstr = ss.str();
54 const char *pstr = cpp_pstr.c_str();
55 BENCHDNN_PRINT(1, "run: %s\n", pstr);
56
57 res_t res {};
58 doit(&prb, &res);
59
60 parse_result(res, pstr);
61
62 if (is_bench_mode(PERF)) {
63 perf_report_t pr(&prb, s.perf_template);
64 pr.report(&res, pstr);
65 }
66 }
67}
68
69static const std::string help_p
70 = "FLOAT (Default: `1.f`)\n Specifies algorithm parameter "
71 "extension where applicable.\n";
72
73static const std::string help_eps
74 = "FLOAT (Default: `0.f`)\n Specifies algorithm parameter "
75 "extension where applicable.\n";
76
77int bench(int argc, char **argv) {
78 driver_name = "reduction";
79 using namespace parser;
80 static settings_t s;
81 static const settings_t def {};
82 for (; argc > 0; --argc, ++argv) {
83 const bool parsed_options = parse_bench_settings(argv[0])
84 || parse_batch(bench, argv[0])
85 || parse_dt(s.sdt, def.sdt, argv[0], "sdt")
86 || parse_dt(s.ddt, def.ddt, argv[0], "ddt")
87 || parse_tag(s.stag, def.stag, argv[0], "stag")
88 || parse_tag(s.dtag, def.dtag, argv[0], "dtag")
89 || parse_alg(s.alg, def.alg, str2alg, argv[0])
90 || parse_vector_option(s.p, def.p, atof, argv[0], "p", help_p)
91 || parse_vector_option(
92 s.eps, def.eps, atof, argv[0], "eps", help_eps)
93 || parse_attr_post_ops(s.post_ops, argv[0])
94 || parse_ctx_init(s.ctx_init, def.ctx_init, argv[0])
95 || parse_ctx_exe(s.ctx_exe, def.ctx_exe, argv[0])
96 || parse_perf_template(s.perf_template, s.perf_template_def,
97 s.perf_template_csv(), argv[0])
98 || parse_reset(s, argv[0]) || parse_help(argv[0]);
99 if (!parsed_options) {
100 catch_unknown_options(argv[0]);
101
102 parse_prb_vdims(s.prb_vdims, argv[0]);
103
104 check_correctness(s);
105 }
106 }
107
108 return parse_last_argument();
109}
110
111} // namespace reduction
112