1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <string.h> |
18 | |
19 | #include <sstream> |
20 | |
21 | #include "dnnl_common.hpp" |
22 | #include "utils/parser.hpp" |
23 | |
24 | #include "reorder.hpp" |
25 | |
26 | namespace reorder { |
27 | |
28 | void check_correctness(const settings_t &s) { |
29 | for_(const auto &i_sdt : s.sdt) |
30 | for_(const auto &i_ddt : s.ddt) |
31 | for_(const auto &i_stag : s.stag) |
32 | for_(const auto &i_dtag : s.dtag) |
33 | for_(const auto &i_oflag : s.oflag) |
34 | for_(const auto &i_cross_engine : s.cross_engine) |
35 | for_(const auto &i_scales : s.scales) |
36 | for_(const auto &i_zero_points : s.zero_points) |
37 | for_(const auto &i_post_ops : s.post_ops) |
38 | for_(const auto &i_scratchpad_mode : s.scratchpad_mode) |
39 | for_(const auto &i_ctx_init : s.ctx_init) |
40 | for_(const auto &i_ctx_exe : s.ctx_exe) |
41 | for (auto i_runtime_dim_mask : s.runtime_dim_mask) { |
42 | const auto &src_scale = i_scales.get(DNNL_ARG_SRC); |
43 | const std::vector<float> src_test_scales = src_scale.scale == 0 |
44 | ? s.def_scale |
45 | : std::vector<float>(1, src_scale.scale); |
46 | const auto &dst_scale = i_scales.get(DNNL_ARG_DST); |
47 | const std::vector<float> dst_test_scales = dst_scale.scale == 0 |
48 | ? s.def_scale |
49 | : std::vector<float>(1, dst_scale.scale); |
50 | |
51 | for_(const auto &i_src_test_scale : src_test_scales) |
52 | for (const auto &i_dst_test_scale : dst_test_scales) { |
53 | attr_t::arg_scales_t test_arg_scales; |
54 | test_arg_scales.set(DNNL_ARG_SRC, |
55 | {src_scale.policy, i_src_test_scale, src_scale.runtime}); |
56 | test_arg_scales.set(DNNL_ARG_DST, |
57 | {dst_scale.policy, i_dst_test_scale, dst_scale.runtime}); |
58 | auto attr = settings_t::get_attr(test_arg_scales, i_zero_points, |
59 | i_post_ops, i_scratchpad_mode); |
60 | |
61 | const prb_t prb(s.prb_dims, i_sdt, i_ddt, i_stag, i_dtag, attr, |
62 | i_ctx_init, i_ctx_exe, i_oflag, i_cross_engine, |
63 | i_runtime_dim_mask); |
64 | std::stringstream ss; |
65 | ss << prb; |
66 | const std::string cpp_pstr = ss.str(); |
67 | const char *pstr = cpp_pstr.c_str(); |
68 | BENCHDNN_PRINT(1, "run: %s\n" , pstr); |
69 | |
70 | res_t res {}; |
71 | doit(&prb, &res); |
72 | |
73 | parse_result(res, pstr); |
74 | |
75 | if (is_bench_mode(PERF)) { |
76 | perf_report_t pr(&prb, s.perf_template); |
77 | pr.report(&res, pstr); |
78 | } |
79 | } |
80 | } |
81 | } |
82 | |
83 | int verify_input(const settings_t &s) { |
84 | for_(const auto &i_scales : s.scales) |
85 | for (auto arg : {DNNL_ARG_SRC, DNNL_ARG_DST}) { |
86 | if (i_scales.get(arg).policy == policy_t::PER_OC) { |
87 | BENCHDNN_PRINT(0, "%s\n" , |
88 | "ERROR: `per_oc` policy is not supported due to " |
89 | "potential ambiguity. Please use one of `per_dim_0` or " |
90 | "`per_dim_1` policies." ); |
91 | return FAIL; |
92 | } |
93 | } |
94 | |
95 | for (const auto &i_cross_engine : s.cross_engine) { |
96 | if (i_cross_engine != NONE && is_cpu()) { |
97 | BENCHDNN_PRINT(0, "%s\n" , |
98 | "ERROR: `cpu` engine does not support anything but " |
99 | "`--cross-engine=none`." ); |
100 | return FAIL; |
101 | } |
102 | } |
103 | |
104 | return OK; |
105 | } |
106 | |
// Help text for the `oflag` option; handed to the option parser in bench().
// Describes the `FLAG:MASK[+...]` syntax accepted by that option.
static const std::string help_oflag
        = "FLAG:MASK[+...] (Default: not specified)\n Specifies `extra` "
          "field of destination memory descriptor.\n `FLAG` values are "
          "`s8s8_comp` and `zp_comp`.\n `MASK` is a non-negative integer "
          "specifying dimension to apply compensation.\n";
112 | |
// Help text for the `runtime-dim-mask` option; handed to the option parser in
// bench(). One literal piece per printed line where it fits.
static const std::string help_runtime_dim_mask(
        "UINT (Default: `0`)\n"
        " Specifies a bit-mask that indicates whether a dimension is "
        "`DNNL_RUNTIME_DIM_VAL` if `1` on a correspondent dimension.\n");
117 | |
// Help text for the `def-scales` option; handed to the option parser in
// bench(). One literal piece per printed line.
static const std::string help_def_scales(
        "FLOAT\n"
        " Scales, used to improve testing coverage.\n"
        " If `--attr-scales` is specified, does not have an effect.\n");
121 | |
// Help text for the `cross-engine` option; handed to the option parser in
// bench(). One literal piece per printed line.
static const std::string help_cross_engine(
        "KIND (Default: `none`)\n"
        " Specifies `KIND` of cross-engine used for benchmarking.\n"
        " `KIND` values are `none`, `cpu2gpu` or `gpu2cpu`.\n");
126 | |
127 | int bench(int argc, char **argv) { |
128 | driver_name = "reorder" ; |
129 | using namespace parser; |
130 | static settings_t s; |
131 | static const settings_t def {}; |
132 | for (; argc > 0; --argc, ++argv) { |
133 | const bool parsed_options = parse_bench_settings(argv[0]) |
134 | || parse_batch(bench, argv[0]) |
135 | || parse_dt(s.sdt, def.sdt, argv[0], "sdt" ) |
136 | || parse_dt(s.ddt, def.ddt, argv[0], "ddt" ) |
137 | || parse_tag(s.stag, def.stag, argv[0], "stag" ) |
138 | || parse_tag(s.dtag, def.dtag, argv[0], "dtag" ) |
139 | || parse_multivector_option(s.oflag, def.oflag, str2flag, |
140 | argv[0], "oflag" , help_oflag, ',', '+') |
141 | || parse_vector_option(s.runtime_dim_mask, def.runtime_dim_mask, |
142 | atoi, argv[0], "runtime-dim-mask" , |
143 | help_runtime_dim_mask) |
144 | || parse_vector_option(s.def_scale, def.def_scale, atof, |
145 | argv[0], "def-scales" , help_def_scales) |
146 | || parse_vector_option(s.cross_engine, def.cross_engine, |
147 | str2cross_engine, argv[0], "cross-engine" , |
148 | help_cross_engine) |
149 | || parse_attr_scales(s.scales, argv[0]) |
150 | || parse_attr_zero_points(s.zero_points, argv[0]) |
151 | || parse_attr_post_ops(s.post_ops, argv[0]) |
152 | || parse_attr_scratchpad_mode( |
153 | s.scratchpad_mode, def.scratchpad_mode, argv[0]) |
154 | || parse_ctx_init(s.ctx_init, def.ctx_init, argv[0]) |
155 | || parse_ctx_exe(s.ctx_exe, def.ctx_exe, argv[0]) |
156 | || parse_perf_template(s.perf_template, s.perf_template_def, |
157 | s.perf_template_csv(), argv[0]) |
158 | || parse_reset(s, argv[0]) || parse_help(argv[0]); |
159 | if (!parsed_options) { |
160 | catch_unknown_options(argv[0]); |
161 | |
162 | parse_prb_dims(s.prb_dims, argv[0]); |
163 | |
164 | SAFE(verify_input(s), WARN); |
165 | check_correctness(s); |
166 | } |
167 | } |
168 | |
169 | return parse_last_argument(); |
170 | } |
171 | |
172 | } // namespace reorder |
173 | |