1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <stdio.h>
18#include <stdlib.h>
19#include <string.h>
20
21#include <sstream>
22
23#include "dnnl_common.hpp"
24#include "utils/parser.hpp"
25
26#include "bnorm/bnorm.hpp"
27
28namespace bnorm {
29
30void check_correctness(const settings_t &s) {
31 for_(const auto &i_dir : s.dir)
32 for_(const auto &i_dt : s.dt)
33 for_(const auto &i_tag : s.tag)
34 for_(const auto &i_flags : s.flags)
35 for_(const auto &i_mb : s.mb)
36 for_(const auto &i_post_ops : s.post_ops)
37 for_(const auto &i_scratchpad_mode : s.scratchpad_mode)
38 for_(const auto &i_ctx_init : s.ctx_init)
39 for_(const auto &i_ctx_exe : s.ctx_exe)
40 for (auto i_inplace : s.inplace) {
41 auto attr = settings_t::get_attr(i_post_ops, i_scratchpad_mode);
42
43 const prb_t prb(s.desc, i_mb, i_dir, i_dt, i_tag, i_flags, i_inplace,
44 attr, i_ctx_init, i_ctx_exe, s.check_alg, s.debug_check_ws);
45 std::stringstream ss;
46 ss << prb;
47 const std::string cpp_pstr = ss.str();
48 const char *pstr = cpp_pstr.c_str();
49
50 if (s.pattern && !match_regex(pstr, s.pattern)) return;
51 BENCHDNN_PRINT(1, "run: %s\n", pstr);
52
53 res_t res {};
54 doit(&prb, &res);
55
56 parse_result(res, pstr);
57
58 if (is_bench_mode(PERF)) {
59 perf_report_t pr(&prb, s.perf_template);
60 pr.report(&res, pstr);
61 }
62 }
63}
64
// Help text printed for the `--flags` option. Each letter in the option value
// selects one normalization flag.
// NOTE(review): the `A` bullet is emitted with one more leading space than the
// other bullets; confirm whether that is intended.
static const std::string help_flags
        = "FLAGS (Default: not specified)\n"
          " Specifies normalization flags. `FLAGS` values are:\n"
          " * `G` for global_stats.\n"
          " * `C` for scale.\n"
          " * `H` for shift.\n"
          " * `R` for fuse_norm_relu.\n"
          "  * `A` for fuse_norm_add_relu.\n";
70
// Help text printed for the `--check-alg` developer debug option. Its value
// overrides the driver's automatic choice of validation algorithm.
static const std::string help_check_alg
        = "CHECK_ALG\n"
          " Dev debug setting to validate output for different inputs. "
          "Overrides driver's automatic choice.\n"
          " `CHECK_ALG` values are `alg_0` or `alg_1`.\n";
75
// Help text printed for the `--debug-check-ws` option (a boolean that
// defaults to `false`). Fixed the grammar of the user-facing sentence:
// "to validates" -> "to validate".
static const std::string help_debug_check_ws
        = "BOOL (Default: `false`)\n Instructs the driver to validate "
          "workspace correctness on forward prop kind when set to `true`.\n";
79
80int bench(int argc, char **argv) {
81 driver_name = "bnorm";
82 using namespace parser;
83 static settings_t s;
84 static const settings_t def {};
85 for (; argc > 0; --argc, ++argv) {
86 const bool parsed_options = parse_bench_settings(argv[0])
87 || parse_batch(bench, argv[0])
88 || parse_dir(s.dir, def.dir, argv[0])
89 || parse_dt(s.dt, def.dt, argv[0])
90 || parse_tag(s.tag, def.tag, argv[0])
91 || parse_vector_option(s.flags, def.flags, str2flags, argv[0],
92 "flags", help_flags)
93 || parse_single_value_option(s.check_alg, def.check_alg,
94 str2check_alg, argv[0], "check-alg", help_check_alg)
95 || parse_inplace(s.inplace, def.inplace, argv[0])
96 || parse_mb(s.mb, def.mb, argv[0])
97 || parse_single_value_option(s.debug_check_ws,
98 def.debug_check_ws, str2bool, argv[0], "debug-check-ws",
99 help_debug_check_ws)
100 || parse_attr_post_ops(s.post_ops, argv[0])
101 || parse_attr_scratchpad_mode(
102 s.scratchpad_mode, def.scratchpad_mode, argv[0])
103 || parse_ctx_init(s.ctx_init, def.ctx_init, argv[0])
104 || parse_ctx_exe(s.ctx_exe, def.ctx_exe, argv[0])
105 || parse_test_pattern_match(s.pattern, argv[0])
106 || parse_perf_template(s.perf_template, s.perf_template_def,
107 s.perf_template_csv(), argv[0])
108 || parse_reset(s, argv[0]) || parse_help(argv[0]);
109 if (!parsed_options) {
110 catch_unknown_options(argv[0]);
111
112 SAFE(str2desc(&s.desc, argv[0]), CRIT);
113 check_correctness(s);
114 }
115 }
116
117 return parse_last_argument();
118}
119
120} // namespace bnorm
121