1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
#include <algorithm>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
21
22#include "reorder/reorder.hpp"
23#include "utils/parser.hpp"
24
25namespace reorder {
26
27flag_t str2flag(const char *str) {
28 std::string s(str);
29 if (s.empty()) return std::make_pair(FLAG_NONE, 0);
30
31 size_t start_pos = 0;
32 // format of single entry is `flag_bit:mask`
33 auto sub = parser::get_substr(s, start_pos, ':');
34 std::transform(sub.begin(), sub.end(), sub.begin(), ::tolower);
35
36 flag_bit_t flag = FLAG_NONE;
37 if (sub.compare("s8s8_comp") == 0)
38 flag = FLAG_S8S8_COMP;
39 else if (sub.compare("zp_comp") == 0)
40 flag = FLAG_ZP_COMP;
41 else {
42 assert(!"unknown flag");
43 SAFE_V(FAIL);
44 }
45
46 int mask = std::stoi(parser::get_substr(s, start_pos, ':'));
47 if (mask < 0) {
48 fprintf(stderr,
49 "ERROR: reorder driver: `mask` should be non-negative.\n"),
50 fflush(stderr);
51 SAFE_V(FAIL);
52 }
53
54 return std::make_pair(flag, mask);
55}
56
57std::string flag_name2str(flag_bit_t flag) {
58 if (flag == FLAG_S8S8_COMP) return "s8s8_comp";
59 if (flag == FLAG_ZP_COMP) return "zp_comp";
60 assert(!"unsupported flag");
61 return "";
62}
63
64std::ostream &operator<<(std::ostream &s, const std::vector<flag_t> &oflag) {
65 if (oflag[0].first == FLAG_NONE) return s;
66
67 const char *delim = "";
68 for (const auto &i_oflag : oflag) {
69 s << delim << flag_name2str(i_oflag.first) << ":" << i_oflag.second;
70 delim = "+";
71 }
72 return s;
73}
74
75cross_engine_t str2cross_engine(const char *str) {
76 if (!strcasecmp("none", str)) return NONE;
77 if (!strcasecmp("cpu2gpu", str)) return CPU2GPU;
78 if (!strcasecmp("gpu2cpu", str)) return GPU2CPU;
79 assert(!"unknown cross engine");
80 return NONE;
81}
82
83const char *cross_engine2str(cross_engine_t cross_engine) {
84 switch (cross_engine) {
85 case NONE: return "none";
86 case CPU2GPU: return "cpu2gpu";
87 case GPU2CPU: return "gpu2cpu";
88 default: assert(!"unknown cross engine"); return "unknown cross engine";
89 }
90}
91
92bool prb_t::is_reorder_with_compensation(flag_bit_t flag) const {
93 if (oflag.empty()) return false;
94
95 return std::any_of(oflag.cbegin(), oflag.cend(),
96 [&](const flag_t &oflag) { return (oflag.first & flag); });
97}
98
99void prb_t::get_compensation_parameters(
100 dims_t &comp_dims, int &mask, flag_bit_t flag) const {
101 if (is_reorder_with_compensation(flag)) {
102 for (const auto &i_oflag : oflag) {
103 if (i_oflag.first != flag) continue;
104
105 mask = i_oflag.second;
106 for (int d = 0; d < ndims; ++d)
107 if (mask & (1 << d)) comp_dims.push_back(dims[d]);
108 }
109 }
110}
111
112dims_t prb_t::get_compensation_dims(flag_bit_t flag) const {
113 dims_t comp_dims;
114 int mask = 0;
115 get_compensation_parameters(comp_dims, mask, flag);
116 return comp_dims;
117}
118
119int prb_t::get_compensation_mask(flag_bit_t flag) const {
120 dims_t comp_dims;
121 int mask = 0;
122 get_compensation_parameters(comp_dims, mask, flag);
123 return mask;
124}
125
126float *prb_t::generate_scales(int arg) const {
127 const auto &scales = attr.scales;
128 if (scales.is_def()) return nullptr;
129
130 const auto &e = scales.get(arg);
131 const int mask = attr_t::get_default_mask(e.policy);
132 int64_t uniq_scales = nelems(mask);
133
134 float *values = (float *)zmalloc(sizeof(float) * uniq_scales, 64);
135 SAFE_V(values != nullptr ? OK : FAIL);
136 for (int d = 0; d < uniq_scales; ++d)
137 values[d] = e.scale;
138 if (uniq_scales > 1) values[uniq_scales - 1] /= 2.f;
139 return values;
140}
141
142int32_t *prb_t::generate_zero_points(int arg) const {
143 const attr_t::zero_points_t &zero_points = this->attr.zero_points;
144 if (zero_points.is_def(arg)) return nullptr;
145
146 const auto &e = zero_points.get(arg);
147 assert(e.policy == policy_t::COMMON);
148
149 int32_t *zp = (int32_t *)zmalloc(sizeof(int32_t), 4);
150 SAFE_V(zp != nullptr ? OK : FAIL);
151 zp[0] = e.value;
152 return zp;
153}
154
155dt_conf_t prb_t::get_conf(data_kind_t kind) const {
156 switch (kind) {
157 case SRC: return dt2cfg(sdt);
158 case DST: return dt2cfg(ddt);
159 default: assert(!"unexpected data kind!"); SAFE_V(FAIL);
160 }
161 return dt2cfg(dnnl_f32);
162}
163
164std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
165 dump_global_params(s);
166 settings_t def;
167
168 s << "--sdt=" << prb.sdt << " ";
169 s << "--ddt=" << prb.ddt << " ";
170 s << "--stag=" << prb.stag << " ";
171 s << "--dtag=" << prb.dtag << " ";
172
173 if (canonical || (!prb.oflag.empty() && prb.oflag != def.oflag[0]))
174 s << "--oflag=" << prb.oflag << " ";
175 if (canonical || prb.cross_engine != def.cross_engine[0])
176 s << "--cross-engine=" << cross_engine2str(prb.cross_engine) << " ";
177 if (canonical || prb.runtime_dim_mask != def.runtime_dim_mask[0])
178 s << "--runtime-dim-mask=" << prb.runtime_dim_mask << " ";
179
180 s << prb.attr;
181 if (canonical || prb.ctx_init != def.ctx_init[0])
182 s << "--ctx-init=" << prb.ctx_init << " ";
183 if (canonical || prb.ctx_exe != def.ctx_exe[0])
184 s << "--ctx-exe=" << prb.ctx_exe << " ";
185
186 s << static_cast<prb_dims_t>(prb);
187
188 return s;
189}
190
191} // namespace reorder
192