1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
21 | |
22 | #include "reorder/reorder.hpp" |
23 | #include "utils/parser.hpp" |
24 | |
25 | namespace reorder { |
26 | |
27 | flag_t str2flag(const char *str) { |
28 | std::string s(str); |
29 | if (s.empty()) return std::make_pair(FLAG_NONE, 0); |
30 | |
31 | size_t start_pos = 0; |
32 | // format of single entry is `flag_bit:mask` |
33 | auto sub = parser::get_substr(s, start_pos, ':'); |
34 | std::transform(sub.begin(), sub.end(), sub.begin(), ::tolower); |
35 | |
36 | flag_bit_t flag = FLAG_NONE; |
37 | if (sub.compare("s8s8_comp" ) == 0) |
38 | flag = FLAG_S8S8_COMP; |
39 | else if (sub.compare("zp_comp" ) == 0) |
40 | flag = FLAG_ZP_COMP; |
41 | else { |
42 | assert(!"unknown flag" ); |
43 | SAFE_V(FAIL); |
44 | } |
45 | |
46 | int mask = std::stoi(parser::get_substr(s, start_pos, ':')); |
47 | if (mask < 0) { |
48 | fprintf(stderr, |
49 | "ERROR: reorder driver: `mask` should be non-negative.\n" ), |
50 | fflush(stderr); |
51 | SAFE_V(FAIL); |
52 | } |
53 | |
54 | return std::make_pair(flag, mask); |
55 | } |
56 | |
57 | std::string flag_name2str(flag_bit_t flag) { |
58 | if (flag == FLAG_S8S8_COMP) return "s8s8_comp" ; |
59 | if (flag == FLAG_ZP_COMP) return "zp_comp" ; |
60 | assert(!"unsupported flag" ); |
61 | return "" ; |
62 | } |
63 | |
64 | std::ostream &operator<<(std::ostream &s, const std::vector<flag_t> &oflag) { |
65 | if (oflag[0].first == FLAG_NONE) return s; |
66 | |
67 | const char *delim = "" ; |
68 | for (const auto &i_oflag : oflag) { |
69 | s << delim << flag_name2str(i_oflag.first) << ":" << i_oflag.second; |
70 | delim = "+" ; |
71 | } |
72 | return s; |
73 | } |
74 | |
75 | cross_engine_t str2cross_engine(const char *str) { |
76 | if (!strcasecmp("none" , str)) return NONE; |
77 | if (!strcasecmp("cpu2gpu" , str)) return CPU2GPU; |
78 | if (!strcasecmp("gpu2cpu" , str)) return GPU2CPU; |
79 | assert(!"unknown cross engine" ); |
80 | return NONE; |
81 | } |
82 | |
83 | const char *cross_engine2str(cross_engine_t cross_engine) { |
84 | switch (cross_engine) { |
85 | case NONE: return "none" ; |
86 | case CPU2GPU: return "cpu2gpu" ; |
87 | case GPU2CPU: return "gpu2cpu" ; |
88 | default: assert(!"unknown cross engine" ); return "unknown cross engine" ; |
89 | } |
90 | } |
91 | |
92 | bool prb_t::is_reorder_with_compensation(flag_bit_t flag) const { |
93 | if (oflag.empty()) return false; |
94 | |
95 | return std::any_of(oflag.cbegin(), oflag.cend(), |
96 | [&](const flag_t &oflag) { return (oflag.first & flag); }); |
97 | } |
98 | |
99 | void prb_t::get_compensation_parameters( |
100 | dims_t &comp_dims, int &mask, flag_bit_t flag) const { |
101 | if (is_reorder_with_compensation(flag)) { |
102 | for (const auto &i_oflag : oflag) { |
103 | if (i_oflag.first != flag) continue; |
104 | |
105 | mask = i_oflag.second; |
106 | for (int d = 0; d < ndims; ++d) |
107 | if (mask & (1 << d)) comp_dims.push_back(dims[d]); |
108 | } |
109 | } |
110 | } |
111 | |
112 | dims_t prb_t::get_compensation_dims(flag_bit_t flag) const { |
113 | dims_t comp_dims; |
114 | int mask = 0; |
115 | get_compensation_parameters(comp_dims, mask, flag); |
116 | return comp_dims; |
117 | } |
118 | |
119 | int prb_t::get_compensation_mask(flag_bit_t flag) const { |
120 | dims_t comp_dims; |
121 | int mask = 0; |
122 | get_compensation_parameters(comp_dims, mask, flag); |
123 | return mask; |
124 | } |
125 | |
126 | float *prb_t::generate_scales(int arg) const { |
127 | const auto &scales = attr.scales; |
128 | if (scales.is_def()) return nullptr; |
129 | |
130 | const auto &e = scales.get(arg); |
131 | const int mask = attr_t::get_default_mask(e.policy); |
132 | int64_t uniq_scales = nelems(mask); |
133 | |
134 | float *values = (float *)zmalloc(sizeof(float) * uniq_scales, 64); |
135 | SAFE_V(values != nullptr ? OK : FAIL); |
136 | for (int d = 0; d < uniq_scales; ++d) |
137 | values[d] = e.scale; |
138 | if (uniq_scales > 1) values[uniq_scales - 1] /= 2.f; |
139 | return values; |
140 | } |
141 | |
142 | int32_t *prb_t::generate_zero_points(int arg) const { |
143 | const attr_t::zero_points_t &zero_points = this->attr.zero_points; |
144 | if (zero_points.is_def(arg)) return nullptr; |
145 | |
146 | const auto &e = zero_points.get(arg); |
147 | assert(e.policy == policy_t::COMMON); |
148 | |
149 | int32_t *zp = (int32_t *)zmalloc(sizeof(int32_t), 4); |
150 | SAFE_V(zp != nullptr ? OK : FAIL); |
151 | zp[0] = e.value; |
152 | return zp; |
153 | } |
154 | |
155 | dt_conf_t prb_t::get_conf(data_kind_t kind) const { |
156 | switch (kind) { |
157 | case SRC: return dt2cfg(sdt); |
158 | case DST: return dt2cfg(ddt); |
159 | default: assert(!"unexpected data kind!" ); SAFE_V(FAIL); |
160 | } |
161 | return dt2cfg(dnnl_f32); |
162 | } |
163 | |
164 | std::ostream &operator<<(std::ostream &s, const prb_t &prb) { |
165 | dump_global_params(s); |
166 | settings_t def; |
167 | |
168 | s << "--sdt=" << prb.sdt << " " ; |
169 | s << "--ddt=" << prb.ddt << " " ; |
170 | s << "--stag=" << prb.stag << " " ; |
171 | s << "--dtag=" << prb.dtag << " " ; |
172 | |
173 | if (canonical || (!prb.oflag.empty() && prb.oflag != def.oflag[0])) |
174 | s << "--oflag=" << prb.oflag << " " ; |
175 | if (canonical || prb.cross_engine != def.cross_engine[0]) |
176 | s << "--cross-engine=" << cross_engine2str(prb.cross_engine) << " " ; |
177 | if (canonical || prb.runtime_dim_mask != def.runtime_dim_mask[0]) |
178 | s << "--runtime-dim-mask=" << prb.runtime_dim_mask << " " ; |
179 | |
180 | s << prb.attr; |
181 | if (canonical || prb.ctx_init != def.ctx_init[0]) |
182 | s << "--ctx-init=" << prb.ctx_init << " " ; |
183 | if (canonical || prb.ctx_exe != def.ctx_exe[0]) |
184 | s << "--ctx-exe=" << prb.ctx_exe << " " ; |
185 | |
186 | s << static_cast<prb_dims_t>(prb); |
187 | |
188 | return s; |
189 | } |
190 | |
191 | } // namespace reorder |
192 | |