1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnn_types.hpp"
18#include "dnnl_common.hpp"
19
20#include "deconv/deconv.hpp"
21
22namespace deconv {
23
24alg_t str2alg(const char *str) {
25#define CASE(_alg) \
26 if (!strcasecmp(STRINGIFY(_alg), str)) return _alg
27 CASE(DIRECT);
28 CASE(deconvolution_direct);
29 CASE(WINO);
30 CASE(deconvolution_wino);
31#undef CASE
32 assert(!"unknown algorithm");
33 return UNDEF;
34}
35
36const char *alg2str(alg_t alg) {
37 if (alg == DIRECT) return "direct";
38 if (alg == WINO) return "wino";
39 assert(!"unknown algorithm");
40 return "undef";
41}
42
43alg_t alg_kind2alg(dnnl_alg_kind_t alg) {
44 if (alg == dnnl_deconvolution_direct) return DIRECT;
45 if (alg == dnnl_deconvolution_winograd) return WINO;
46 assert(!"unknown algorithm");
47 return DIRECT;
48}
49
// Parses the textual problem descriptor `str` into `*desc`.
//
// Returns OK on success and FAIL on any parse or consistency error. `*desc`
// is written only on the success path; on failure it is left untouched.
// `str` must be non-null (asserted).
int str2desc(desc_t *desc, const char *str) {
    /* canonical form:
     * gXmbX_icXidXihXiwX_ocXodXohXowX_kdXkhXkwX_sdXshXswX_pdXphXpwX_ddXdhXdwXnS
     *
     * where X is number, S - string
     * note: symbol `_` is ignored
     *
     * implicit rules:
     *  - if smaller dimensions are not specified => square or cubic form;
     *  - if output is undefined => compute output;
     *  - if padding is undefined => compute trivial padding;
     */

    // Defaults: one group, minibatch of 2, unit strides. Padding starts at -1,
    // which the code below treats as "not specified by the user".
    desc_t d {0};
    d.g = 1;
    d.mb = 2;
    d.sd = d.sh = d.sw = 1;
    d.pd = d.ph = d.pw = -1;

    const char *s = str;
    assert(s);

    // CASE_NN(prb, c): if the text at `s` starts with the tag `prb`, consume
    // the tag, parse the base-10 number after it into `d.c`, and advance `s`
    // past the digits. Sets `ok` to record that this iteration matched a tag.
    // A negative parsed value makes str2desc() return FAIL.
#define CASE_NN(prb, c) \
    do { \
        if (!strncmp(prb, s, strlen(prb))) { \
            ok = 1; \
            s += strlen(prb); \
            char *end_s; \
            d.c = strtol(s, &end_s, 10); \
            s += (end_s - s); \
            /* check any # groups, including one, works correctly */ \
            if (!strncmp(prb, "g", 1)) d.has_groups = true; \
            if (d.c < 0) return FAIL; \
            /* printf("@@@debug: %s: %d\n", prb, d. c); */ \
        } \
    } while (0)
#define CASE_N(c) CASE_NN(#c, c)
    // Every iteration tries all tags in the fixed order below, so several
    // consecutive tags can be consumed in one pass. An iteration that matches
    // no tag (and is not the start of the name) fails the parse.
    while (*s) {
        int ok = 0;
        CASE_N(g);
        CASE_N(mb);
        CASE_N(ic);
        CASE_N(id);
        CASE_N(ih);
        CASE_N(iw);
        CASE_N(oc);
        CASE_N(od);
        CASE_N(oh);
        CASE_N(ow);
        CASE_N(kd);
        CASE_N(kh);
        CASE_N(kw);
        CASE_N(sd);
        CASE_N(sh);
        CASE_N(sw);
        CASE_N(pd);
        CASE_N(ph);
        CASE_N(pw);
        CASE_N(dd);
        CASE_N(dh);
        CASE_N(dw);
        // 'n' starts the free-form problem name: the rest of the string is
        // taken verbatim and parsing stops.
        if (*s == 'n') {
            d.name = s + 1;
            break;
        }
        // A single separator is skipped here. Note the `ok` check still
        // applies, so a `_` in an iteration that matched no tag is an error.
        if (*s == '_') ++s;
        if (!ok) return FAIL;
    }
#undef CASE_NN
#undef CASE_N

    // Basic sanity: explicit groups must be positive, channels must be set,
    // strides must be positive.
    if (d.has_groups && d.g <= 0) return FAIL;
    if (d.ic == 0 || d.oc == 0) return FAIL;
    if (d.sd <= 0 || d.sh <= 0 || d.sw <= 0) return FAIL;

    // Output extent from input extent i, kernel k, stride s, padding p and
    // dilation d.
    auto compute_out
            = [](int64_t i, int64_t k, int64_t s, int64_t p, int64_t d) {
                  return (i - 1) * s + (k - 1) * (d + 1) - 2 * p + 1;
              };
    // Padding from output o and input i; this is compute_out solved for p,
    // with integer division by two.
    auto compute_pad
            = [](int64_t o, int64_t i, int64_t k, int64_t s, int64_t d) {
                  return ((i - 1) * s - o + ((k - 1) * (d + 1) + 1)) / 2;
              };

    // An axis counts as "not given" when none of its input/kernel/output/
    // dilation values were set, its stride is still 1 and its padding is < 1.
    const bool no_d = (d.id | d.kd | d.od | d.dd) == 0 && d.sd == 1 && d.pd < 1;
    const bool no_h = (d.ih | d.kh | d.oh | d.dh) == 0 && d.sh == 1 && d.ph < 1;
    const bool no_w = (d.iw | d.kw | d.ow | d.dw) == 0 && d.sw == 1 && d.pw < 1;

    // For every axis that was given: input and kernel are mandatory. If the
    // output is missing it is computed (unspecified padding becomes 0);
    // otherwise unspecified padding is derived from the given output.
    if (!no_d) {
        if (!d.id || !d.kd) return FAIL;
        if (!d.od) {
            if (d.pd < 0) d.pd = 0;
            d.od = compute_out(d.id, d.kd, d.sd, d.pd, d.dd);
            if (d.od <= 0) return FAIL;
        } else if (d.pd < 0)
            d.pd = compute_pad(d.od, d.id, d.kd, d.sd, d.dd);
    }

    if (!no_h) {
        if (!d.ih || !d.kh) return FAIL;
        if (!d.oh) {
            if (d.ph < 0) d.ph = 0;
            d.oh = compute_out(d.ih, d.kh, d.sh, d.ph, d.dh);
            if (d.oh <= 0) return FAIL;
        } else if (d.ph < 0)
            d.ph = compute_pad(d.oh, d.ih, d.kh, d.sh, d.dh);
    }

    if (!no_w) {
        if (!d.iw || !d.kw) return FAIL;
        if (!d.ow) {
            if (d.pw < 0) d.pw = 0;
            d.ow = compute_out(d.iw, d.kw, d.sw, d.pw, d.dw);
            if (d.ow <= 0) return FAIL;
        } else if (d.pw < 0)
            d.pw = compute_pad(d.ow, d.iw, d.kw, d.sw, d.dw);
    }

    // NOTE(review): sanitize_desc() is defined elsewhere; presumably it fills
    // d.ndims and the axes that were not given (using the default list passed
    // as the fifth argument) -- confirm against its definition.
    if (sanitize_desc(d.ndims, {d.od, d.id, d.kd, d.sd, d.pd, d.dd},
                {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh},
                {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw}, {1, 1, 1, 1, 0, 0}, true)
            != OK)
        return FAIL;

    // NOTE(review): init_pad_r() is defined elsewhere; presumably it derives
    // the right-hand paddings (pd_r/ph_r/pw_r) -- confirm.
    d.init_pad_r();
    *desc = d;

    return OK;
}
179
180std::ostream &operator<<(std::ostream &s, const desc_t &d) {
181 bool print_d = true, print_h = true, print_w = true;
182 print_dhw(print_d, print_h, print_w, d.ndims,
183 {d.od, d.id, d.kd, d.sd, d.pd, d.dd},
184 {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh},
185 {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw});
186
187 auto print_spatial
188 = [&](const char *d_str, int64_t d_val, const char *h_str,
189 int64_t h_val, const char *w_str, int64_t w_val) {
190 if (print_d) s << d_str << d_val;
191 if (print_h) s << h_str << h_val;
192 if (print_w) s << w_str << w_val;
193 };
194
195 if (canonical || d.has_groups) s << "g" << d.g;
196 if (canonical || d.mb != 2) s << "mb" << d.mb;
197 s << "ic" << d.ic;
198 print_spatial("id", d.id, "ih", d.ih, "iw", d.iw);
199 s << "oc" << d.oc;
200 print_spatial("od", d.od, "oh", d.oh, "ow", d.ow);
201 print_spatial("kd", d.kd, "kh", d.kh, "kw", d.kw);
202
203 if (canonical || d.sh != 1 || d.sw != 1 || d.sd != 1)
204 print_spatial("sd", d.sd, "sh", d.sh, "sw", d.sw);
205
206 print_spatial("pd", d.pd, "ph", d.ph, "pw", d.pw);
207
208 if (canonical || d.dh != 0 || d.dw != 0 || d.dd != 0)
209 print_spatial("dd", d.dd, "dh", d.dh, "dw", d.dw);
210
211 if (!d.name.empty()) s << "n" << d.name;
212
213 return s;
214}
215
216dims_t desc_t::src_dims() const {
217 dims_t src_dims {mb, ic, id, ih, iw};
218 for (int d = 0; d < 5 - ndims; ++d) {
219 src_dims.erase(src_dims.begin() + 2);
220 }
221
222 return src_dims;
223}
224
225dims_t desc_t::wei_dims() const {
226 dims_t wei_dims {g, oc / g, ic / g, kd, kh, kw};
227 if (!has_groups) { wei_dims.erase(wei_dims.begin()); }
228 for (int d = 0; d < 5 - ndims; ++d) {
229 wei_dims.erase(wei_dims.begin() + 2 + has_groups);
230 }
231
232 return wei_dims;
233}
234
235dims_t desc_t::bia_dims() const {
236 dims_t bia_dims {oc};
237 return bia_dims;
238}
239
240dims_t desc_t::dst_dims() const {
241 dims_t dst_dims {mb, oc, od, oh, ow};
242 for (int d = 0; d < 5 - ndims; ++d) {
243 dst_dims.erase(dst_dims.begin() + 2);
244 }
245
246 return dst_dims;
247}
248
249dims_t desc_t::strides() const {
250 dims_t strides {sd, sh, sw};
251 return dims_t(strides.begin() + (5 - ndims), strides.end());
252}
253
254dims_t desc_t::dilations() const {
255 dims_t dilations {dd, dh, dw};
256 return dims_t(dilations.begin() + (5 - ndims), dilations.end());
257}
258
259dims_t desc_t::padding() const {
260 dims_t padding {pd, ph, pw};
261 return dims_t(padding.begin() + (5 - ndims), padding.end());
262}
263
264dims_t desc_t::padding_r() const {
265 dims_t padding_r {pd_r, ph_r, pw_r};
266 return dims_t(padding_r.begin() + (5 - ndims), padding_r.end());
267}
268
269int64_t desc_t::desc_nelems(int arg, int mask) const {
270 std::vector<int64_t> src {mb, ic, id, ih, iw};
271 std::vector<int64_t> wei {g, oc, ic, kd, kh, kw};
272 std::vector<int64_t> dst {mb, oc, od, oh, ow};
273 std::vector<int64_t> dummy;
274
275 if (!has_groups) wei.erase(wei.begin());
276
277 for (int d = 0; d < 5 - ndims; d++) {
278 src.erase(src.begin() + 2);
279 wei.erase(wei.begin() + 2);
280 dst.erase(dst.begin() + 2);
281 }
282
283 std::vector<int64_t> &dims = dummy;
284 switch (arg) {
285 case DNNL_ARG_SRC: dims = src; break;
286 case DNNL_ARG_WEIGHTS: dims = wei; break;
287 case DNNL_ARG_DST: dims = dst; break;
288 default: assert(!"unsupported arg");
289 }
290
291 int64_t nelems = 1;
292 for (int d = 0; d < ndims; d++) {
293 nelems *= (mask & (1 << d)) ? dims[d] : 1;
294 }
295 return nelems;
296}
297
298void prb_t::count_ops() {
299 if (ops > 0) return;
300
301 int64_t od_t = this->id;
302 int64_t oh_t = this->ih;
303 int64_t ow_t = this->iw;
304 int64_t id_t = this->od;
305 int64_t ih_t = this->oh;
306 int64_t iw_t = this->ow;
307 double sp_ops = 0;
308 for_(int64_t od = 0; od < od_t; ++od)
309 for_(int64_t oh = 0; oh < oh_t; ++oh)
310 for (int64_t ow = 0; ow < ow_t; ++ow) {
311 for (int64_t kd = 0; kd < this->kd; ++kd) {
312 const int64_t id = od * this->sd - this->pd + kd * (this->dd + 1);
313 if (id < 0 || id >= id_t) continue;
314 for (int64_t kh = 0; kh < this->kh; ++kh) {
315 const int64_t ih
316 = oh * this->sh - this->ph + kh * (this->dh + 1);
317 if (ih < 0 || ih >= ih_t) continue;
318 for (int64_t kw = 0; kw < this->kw; ++kw) {
319 const int64_t iw
320 = ow * this->sw - this->pw + kw * (this->dw + 1);
321 if (iw < 0 || iw >= iw_t) continue;
322 sp_ops += 1;
323 }
324 }
325 }
326 }
327
328 ops = 2 * this->mb * this->oc * this->ic / this->g * sp_ops;
329}
330
331float *prb_t::generate_scales(int arg) {
332 const auto &scales = attr.scales;
333 if (scales.is_def()) return nullptr;
334
335 const auto &e = scales.get(arg);
336 if (e.policy == policy_t::COMMON) {
337 float *s = (float *)zmalloc(sizeof(float), 4);
338 SAFE_V(s != nullptr ? OK : FAIL);
339 s[0] = e.scale;
340 return s;
341 }
342
343 assert(e.policy == policy_t::PER_OC);
344 auto mask = attr_t::get_default_mask(e.policy, arg);
345 if (arg == DNNL_ARG_WEIGHTS && has_groups) mask = (1 << mask) + 1;
346 int64_t s_nelems = desc_nelems(arg, mask);
347
348 float *s = (float *)zmalloc(sizeof(float) * s_nelems, 64);
349 SAFE_V(s != nullptr ? OK : FAIL);
350
351 const float K = 32;
352 /* scale in [1/K .. K], with starting point at e.scale */
353 float s_val[2] = {e.scale, e.scale / 2};
354 for (int64_t i = 0; i < s_nelems; ++i) {
355 int64_t si = i % 2; // 0 -> left, 1 -> right
356 s[i] = s_val[si];
357 if (si == 0) {
358 s_val[si] /= 2.;
359 // turn around to become ~K
360 if (s_val[si] < 1. / K) s_val[si] *= K * K;
361 } else {
362 s_val[si] *= 2.;
363 // turn around to become ~K
364 if (s_val[si] > K) s_val[si] /= K * K;
365 }
366 }
367 return s;
368}
369
370int32_t *prb_t::generate_zero_points(int arg) {
371 const auto &zp = attr.zero_points;
372 if (zp.is_def(arg)) return nullptr;
373
374 const auto &e = zp.get(arg);
375 const auto mask = attr_t::get_default_mask(e.policy);
376 int64_t zp_nelems = desc_nelems(arg, mask);
377
378 int32_t *ptr = (int32_t *)zmalloc(sizeof(int32_t) * zp_nelems, 64);
379 SAFE_V(ptr != nullptr ? OK : FAIL);
380
381 for (int i = 0; i < zp_nelems; ++i)
382 ptr[i] = e.value + i % 3;
383
384 return ptr;
385}
386
387std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
388 dump_global_params(s);
389 settings_t def;
390
391 if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " ";
392 if (canonical || prb.cfg != def.cfg[0]) s << "--cfg=" << prb.cfg << " ";
393 if (canonical || prb.stag != def.stag[0]) s << "--stag=" << prb.stag << " ";
394 if (canonical || prb.wtag != def.wtag[0]) s << "--wtag=" << prb.wtag << " ";
395 if (canonical || prb.dtag != def.dtag[0]) s << "--dtag=" << prb.dtag << " ";
396 if (canonical || prb.alg != def.alg[0])
397 s << "--alg=" << alg2str(prb.alg) << " ";
398
399 s << prb.attr;
400 if (canonical || prb.ctx_init != def.ctx_init[0])
401 s << "--ctx-init=" << prb.ctx_init << " ";
402 if (canonical || prb.ctx_exe != def.ctx_exe[0])
403 s << "--ctx-exe=" << prb.ctx_exe << " ";
404
405 s << static_cast<const desc_t &>(prb);
406
407 return s;
408}
409
410} // namespace deconv
411