1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <float.h> |
18 | #include <math.h> |
19 | #include <stdio.h> |
20 | #include <stdlib.h> |
21 | #include <string.h> |
22 | |
23 | #include "oneapi/dnnl/dnnl.h" |
24 | |
25 | #include "dnn_types.hpp" |
26 | #include "dnnl_common.hpp" |
27 | |
28 | #include "conv/conv.hpp" |
29 | |
30 | namespace conv { |
31 | |
32 | alg_t str2alg(const char *str) { |
33 | #define CASE(_alg) \ |
34 | if (!strcasecmp(STRINGIFY(_alg), str)) return _alg |
35 | CASE(AUTO); |
36 | CASE(convolution_auto); |
37 | CASE(DIRECT); |
38 | CASE(convolution_direct); |
39 | CASE(WINO); |
40 | CASE(convolution_wino); |
41 | #undef CASE |
42 | assert(!"unknown algorithm" ); |
43 | return UNDEF; |
44 | } |
45 | |
46 | const char *alg2str(alg_t alg) { |
47 | if (alg == AUTO) return "auto" ; |
48 | if (alg == DIRECT) return "direct" ; |
49 | if (alg == WINO) return "wino" ; |
50 | assert(!"unknown algorithm" ); |
51 | return "undef" ; |
52 | } |
53 | |
54 | alg_t alg_kind2alg(dnnl_alg_kind_t alg) { |
55 | if (alg == dnnl_convolution_auto) return AUTO; |
56 | if (alg == dnnl_convolution_direct) return DIRECT; |
57 | if (alg == dnnl_convolution_winograd) return WINO; |
58 | assert(!"unknown algorithm" ); |
59 | return DIRECT; |
60 | } |
61 | |
62 | int str2desc(desc_t *desc, const char *str) { |
63 | /* canonical form: |
64 | * gXmbX_icXidXihXiwX_ocXodXohXowX_kdXkhXkwX_sdXshXswX_pdXphXpwX_ddXdhXdwXnS |
65 | * |
66 | * where X is number, S - string |
67 | * note: symbol `_` is ignored |
68 | * |
69 | * implicit rules: |
70 | * - if smaller dimensions are not specified => square or cubic form; |
71 | * - if output is undefined => compute output; |
72 | * - if padding is undefined => compute trivial padding; |
73 | */ |
74 | |
75 | desc_t d {0}; |
76 | d.g = 1; |
77 | d.mb = 2; |
78 | d.sd = d.sh = d.sw = 1; |
79 | d.pd = d.ph = d.pw = -1; |
80 | |
81 | const char *s = str; |
82 | assert(s); |
83 | |
84 | #define CASE_NN(prb, c) \ |
85 | do { \ |
86 | if (!strncmp(prb, s, strlen(prb))) { \ |
87 | ok = 1; \ |
88 | s += strlen(prb); \ |
89 | char *end_s; \ |
90 | d.c = strtol(s, &end_s, 10); \ |
91 | s += (end_s - s); \ |
92 | /* check any # groups, including one, works correctly */ \ |
93 | if (!strncmp(prb, "g", 1)) d.has_groups = true; \ |
94 | if (d.c < 0) return FAIL; \ |
95 | /* printf("@@@debug: %s: %d\n", prb, d. c); */ \ |
96 | } \ |
97 | } while (0) |
98 | #define CASE_N(c) CASE_NN(#c, c) |
99 | while (*s) { |
100 | int ok = 0; |
101 | CASE_N(g); |
102 | CASE_N(mb); |
103 | CASE_N(ic); |
104 | CASE_N(id); |
105 | CASE_N(ih); |
106 | CASE_N(iw); |
107 | CASE_N(oc); |
108 | CASE_N(od); |
109 | CASE_N(oh); |
110 | CASE_N(ow); |
111 | CASE_N(kd); |
112 | CASE_N(kh); |
113 | CASE_N(kw); |
114 | CASE_N(sd); |
115 | CASE_N(sh); |
116 | CASE_N(sw); |
117 | CASE_N(pd); |
118 | CASE_N(ph); |
119 | CASE_N(pw); |
120 | CASE_N(dd); |
121 | CASE_N(dh); |
122 | CASE_N(dw); |
123 | if (*s == 'n') { |
124 | d.name = s + 1; |
125 | break; |
126 | } |
127 | if (*s == '_') ++s; |
128 | if (!ok) return FAIL; |
129 | } |
130 | #undef CASE_NN |
131 | #undef CASE_N |
132 | |
133 | if (d.has_groups && d.g <= 0) return FAIL; |
134 | if (d.ic == 0 || d.oc == 0) return FAIL; |
135 | if (d.sd <= 0 || d.sh <= 0 || d.sw <= 0) return FAIL; |
136 | |
137 | auto compute_out |
138 | = [](int64_t i, int64_t k, int64_t s, int64_t p, int64_t d) { |
139 | return (i - ((k - 1) * (d + 1) + 1) + 2 * p) / s + 1; |
140 | }; |
141 | auto compute_pad |
142 | = [](int64_t o, int64_t i, int64_t k, int64_t s, int64_t d) { |
143 | return ((o - 1) * s - i + ((k - 1) * (d + 1) + 1)) / 2; |
144 | }; |
145 | |
146 | const bool no_d = (d.id | d.kd | d.od | d.dd) == 0 && d.sd == 1 && d.pd < 1; |
147 | const bool no_h = (d.ih | d.kh | d.oh | d.dh) == 0 && d.sh == 1 && d.ph < 1; |
148 | const bool no_w = (d.iw | d.kw | d.ow | d.dw) == 0 && d.sw == 1 && d.pw < 1; |
149 | |
150 | if (!no_d) { |
151 | if (!d.id || !d.kd) return FAIL; |
152 | if (!d.od) { |
153 | if (d.pd < 0) d.pd = 0; |
154 | d.od = compute_out(d.id, d.kd, d.sd, d.pd, d.dd); |
155 | if (d.od <= 0) return FAIL; |
156 | } else if (d.pd < 0) |
157 | d.pd = compute_pad(d.od, d.id, d.kd, d.sd, d.dd); |
158 | } |
159 | |
160 | if (!no_h) { |
161 | if (!d.ih || !d.kh) return FAIL; |
162 | if (!d.oh) { |
163 | if (d.ph < 0) d.ph = 0; |
164 | d.oh = compute_out(d.ih, d.kh, d.sh, d.ph, d.dh); |
165 | if (d.oh <= 0) return FAIL; |
166 | } else if (d.ph < 0) |
167 | d.ph = compute_pad(d.oh, d.ih, d.kh, d.sh, d.dh); |
168 | } |
169 | |
170 | if (!no_w) { |
171 | if (!d.iw || !d.kw) return FAIL; |
172 | if (!d.ow) { |
173 | if (d.pw < 0) d.pw = 0; |
174 | d.ow = compute_out(d.iw, d.kw, d.sw, d.pw, d.dw); |
175 | if (d.ow <= 0) return FAIL; |
176 | } else if (d.pw < 0) |
177 | d.pw = compute_pad(d.ow, d.iw, d.kw, d.sw, d.dw); |
178 | } |
179 | |
180 | if (sanitize_desc(d.ndims, {d.od, d.id, d.kd, d.sd, d.pd, d.dd}, |
181 | {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh}, |
182 | {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw}, {1, 1, 1, 1, 0, 0}, true) |
183 | != OK) |
184 | return FAIL; |
185 | |
186 | d.init_pad_r(); |
187 | *desc = d; |
188 | |
189 | return OK; |
190 | } |
191 | |
192 | std::ostream &operator<<(std::ostream &s, const desc_t &d) { |
193 | bool print_d = true, print_h = true, print_w = true; |
194 | print_dhw(print_d, print_h, print_w, d.ndims, |
195 | {d.od, d.id, d.kd, d.sd, d.pd, d.dd}, |
196 | {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh}, |
197 | {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw}); |
198 | |
199 | auto print_spatial |
200 | = [&](const char *d_str, int64_t d_val, const char *h_str, |
201 | int64_t h_val, const char *w_str, int64_t w_val) { |
202 | if (print_d) s << d_str << d_val; |
203 | if (print_h) s << h_str << h_val; |
204 | if (print_w) s << w_str << w_val; |
205 | }; |
206 | |
207 | if (canonical || d.has_groups) s << "g" << d.g; |
208 | if (canonical || d.mb != 2) s << "mb" << d.mb; |
209 | s << "ic" << d.ic; |
210 | print_spatial("id" , d.id, "ih" , d.ih, "iw" , d.iw); |
211 | s << "oc" << d.oc; |
212 | print_spatial("od" , d.od, "oh" , d.oh, "ow" , d.ow); |
213 | print_spatial("kd" , d.kd, "kh" , d.kh, "kw" , d.kw); |
214 | |
215 | if (canonical || d.sh != 1 || d.sw != 1 || d.sd != 1) |
216 | print_spatial("sd" , d.sd, "sh" , d.sh, "sw" , d.sw); |
217 | |
218 | print_spatial("pd" , d.pd, "ph" , d.ph, "pw" , d.pw); |
219 | |
220 | if (canonical || d.dh != 0 || d.dw != 0 || d.dd != 0) |
221 | print_spatial("dd" , d.dd, "dh" , d.dh, "dw" , d.dw); |
222 | |
223 | if (!d.name.empty()) s << "n" << d.name; |
224 | |
225 | return s; |
226 | } |
227 | |
228 | dims_t desc_t::src_dims() const { |
229 | dims_t src_dims {mb, ic, id, ih, iw}; |
230 | for (int d = 0; d < 5 - ndims; ++d) { |
231 | src_dims.erase(src_dims.begin() + 2); |
232 | } |
233 | |
234 | return src_dims; |
235 | } |
236 | |
237 | dims_t desc_t::wei_dims() const { |
238 | dims_t wei_dims {g, oc / g, ic / g, kd, kh, kw}; |
239 | if (!has_groups) { wei_dims.erase(wei_dims.begin()); } |
240 | for (int d = 0; d < 5 - ndims; ++d) { |
241 | wei_dims.erase(wei_dims.begin() + 2 + has_groups); |
242 | } |
243 | |
244 | return wei_dims; |
245 | } |
246 | |
247 | dims_t desc_t::bia_dims() const { |
248 | dims_t bia_dims {oc}; |
249 | return bia_dims; |
250 | } |
251 | |
252 | dims_t desc_t::dst_dims() const { |
253 | dims_t dst_dims {mb, oc, od, oh, ow}; |
254 | for (int d = 0; d < 5 - ndims; ++d) { |
255 | dst_dims.erase(dst_dims.begin() + 2); |
256 | } |
257 | |
258 | return dst_dims; |
259 | } |
260 | |
261 | dims_t desc_t::strides() const { |
262 | dims_t strides {sd, sh, sw}; |
263 | return dims_t(strides.begin() + (5 - ndims), strides.end()); |
264 | } |
265 | |
266 | dims_t desc_t::dilations() const { |
267 | dims_t dilations {dd, dh, dw}; |
268 | return dims_t(dilations.begin() + (5 - ndims), dilations.end()); |
269 | } |
270 | |
271 | dims_t desc_t::padding() const { |
272 | dims_t padding {pd, ph, pw}; |
273 | return dims_t(padding.begin() + (5 - ndims), padding.end()); |
274 | } |
275 | |
276 | dims_t desc_t::padding_r() const { |
277 | dims_t padding_r {pd_r, ph_r, pw_r}; |
278 | return dims_t(padding_r.begin() + (5 - ndims), padding_r.end()); |
279 | } |
280 | |
281 | int64_t desc_t::desc_nelems(int arg, int mask) const { |
282 | dims_t dims; |
283 | switch (arg) { |
284 | case DNNL_ARG_SRC: dims = src_dims(); break; |
285 | case DNNL_ARG_WEIGHTS: dims = wei_dims(); break; |
286 | case DNNL_ARG_DST: dims = dst_dims(); break; |
287 | default: assert(!"unsupported arg" ); |
288 | } |
289 | |
290 | int64_t nelems = 1; |
291 | for (int d = 0; d < ndims; d++) { |
292 | nelems *= (mask & (1 << d)) ? dims[d] : 1; |
293 | } |
294 | return nelems; |
295 | } |
296 | |
297 | void prb_t::count_ops() { |
298 | if (ops > 0) return; |
299 | |
300 | double sp_ops = 0; |
301 | for_(int64_t od = 0; od < this->od; ++od) |
302 | for_(int64_t oh = 0; oh < this->oh; ++oh) |
303 | for (int64_t ow = 0; ow < this->ow; ++ow) { |
304 | for (int64_t kd = 0; kd < this->kd; ++kd) { |
305 | const int64_t id = od * this->sd - this->pd + kd * (this->dd + 1); |
306 | if (id < 0 || id >= this->id) continue; |
307 | for (int64_t kh = 0; kh < this->kh; ++kh) { |
308 | const int64_t ih |
309 | = oh * this->sh - this->ph + kh * (this->dh + 1); |
310 | if (ih < 0 || ih >= this->ih) continue; |
311 | for (int64_t kw = 0; kw < this->kw; ++kw) { |
312 | const int64_t iw |
313 | = ow * this->sw - this->pw + kw * (this->dw + 1); |
314 | if (iw < 0 || iw >= this->iw) continue; |
315 | sp_ops += 1; |
316 | } |
317 | } |
318 | } |
319 | } |
320 | |
321 | ops = 2 * this->mb * this->oc * this->ic / this->g * sp_ops; |
322 | } |
323 | |
324 | float *prb_t::generate_scales(int arg) const { |
325 | const auto &scales = attr.scales; |
326 | if (scales.is_def()) return nullptr; |
327 | |
328 | const auto &e = scales.get(arg); |
329 | if (e.policy == policy_t::COMMON) { |
330 | float *s = (float *)zmalloc(sizeof(float), 4); |
331 | SAFE_V(s != nullptr ? OK : FAIL); |
332 | s[0] = e.scale; |
333 | return s; |
334 | } |
335 | |
336 | assert(e.policy == policy_t::PER_OC); |
337 | auto mask = attr_t::get_default_mask(e.policy, arg); |
338 | if (arg == DNNL_ARG_WEIGHTS && has_groups) mask = (1 << mask) + 1; |
339 | int64_t s_nelems = desc_nelems(arg, mask); |
340 | |
341 | float *s = (float *)zmalloc(sizeof(float) * s_nelems, 64); |
342 | SAFE_V(s != nullptr ? OK : FAIL); |
343 | |
344 | const float K = 32; |
345 | /* scale in [1/K .. K], with starting point at e.scale */ |
346 | float s_val[2] = {e.scale, e.scale / 2}; |
347 | for (int64_t i = 0; i < s_nelems; ++i) { |
348 | int64_t si = i % 2; // 0 -> left, 1 -> right |
349 | s[i] = s_val[si]; |
350 | if (si == 0) { |
351 | s_val[si] /= 2.; |
352 | // turn around to become ~K |
353 | if (s_val[si] < 1. / K) s_val[si] *= K * K; |
354 | } else { |
355 | s_val[si] *= 2.; |
356 | // turn around to become ~K |
357 | if (s_val[si] > K) s_val[si] /= K * K; |
358 | } |
359 | } |
360 | return s; |
361 | } |
362 | |
363 | int32_t *prb_t::generate_zero_points(int arg) const { |
364 | const auto &zp = attr.zero_points; |
365 | if (zp.is_def(arg)) return nullptr; |
366 | |
367 | const auto &e = zp.get(arg); |
368 | const auto mask = attr_t::get_default_mask(e.policy); |
369 | int64_t zp_nelems = desc_nelems(arg, mask); |
370 | |
371 | int32_t *ptr = (int32_t *)zmalloc(sizeof(int32_t) * zp_nelems, 64); |
372 | SAFE_V(ptr != nullptr ? OK : FAIL); |
373 | |
374 | for (int i = 0; i < zp_nelems; ++i) |
375 | ptr[i] = e.value + i % 3; |
376 | |
377 | return ptr; |
378 | } |
379 | |
380 | std::ostream &operator<<(std::ostream &s, const prb_t &prb) { |
381 | dump_global_params(s); |
382 | settings_t def; |
383 | |
384 | if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " " ; |
385 | if (canonical || prb.cfg != def.cfg[0]) s << "--cfg=" << prb.cfg << " " ; |
386 | if (canonical || prb.stag != def.stag[0]) s << "--stag=" << prb.stag << " " ; |
387 | if (canonical || prb.wtag != def.wtag[0]) s << "--wtag=" << prb.wtag << " " ; |
388 | if (canonical || prb.dtag != def.dtag[0]) s << "--dtag=" << prb.dtag << " " ; |
389 | if (canonical || prb.alg != def.alg[0]) |
390 | s << "--alg=" << alg2str(prb.alg) << " " ; |
391 | |
392 | s << prb.attr; |
393 | if (canonical || prb.ctx_init != def.ctx_init[0]) |
394 | s << "--ctx-init=" << prb.ctx_init << " " ; |
395 | if (canonical || prb.ctx_exe != def.ctx_exe[0]) |
396 | s << "--ctx-exe=" << prb.ctx_exe << " " ; |
397 | |
398 | s << static_cast<const desc_t &>(prb); |
399 | |
400 | return s; |
401 | } |
402 | |
403 | } // namespace conv |
404 | |