1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnn_types.hpp" |
18 | #include "dnnl_common.hpp" |
19 | |
20 | #include "deconv/deconv.hpp" |
21 | |
22 | namespace deconv { |
23 | |
24 | alg_t str2alg(const char *str) { |
25 | #define CASE(_alg) \ |
26 | if (!strcasecmp(STRINGIFY(_alg), str)) return _alg |
27 | CASE(DIRECT); |
28 | CASE(deconvolution_direct); |
29 | CASE(WINO); |
30 | CASE(deconvolution_wino); |
31 | #undef CASE |
32 | assert(!"unknown algorithm" ); |
33 | return UNDEF; |
34 | } |
35 | |
36 | const char *alg2str(alg_t alg) { |
37 | if (alg == DIRECT) return "direct" ; |
38 | if (alg == WINO) return "wino" ; |
39 | assert(!"unknown algorithm" ); |
40 | return "undef" ; |
41 | } |
42 | |
43 | alg_t alg_kind2alg(dnnl_alg_kind_t alg) { |
44 | if (alg == dnnl_deconvolution_direct) return DIRECT; |
45 | if (alg == dnnl_deconvolution_winograd) return WINO; |
46 | assert(!"unknown algorithm" ); |
47 | return DIRECT; |
48 | } |
49 | |
// Parses a deconvolution problem descriptor string `str` into `*desc`.
//
// `str` must be non-null (asserted). `*desc` is written only after every
// check has passed.
//
// Returns OK on success. Returns FAIL on any parse or validation error; in
// that case `*desc` is left unmodified.
int str2desc(desc_t *desc, const char *str) {
    /* canonical form:
     * gXmbX_icXidXihXiwX_ocXodXohXowX_kdXkhXkwX_sdXshXswX_pdXphXpwX_ddXdhXdwXnS
     *
     * where X is number, S - string
     * note: a single `_` between fields is skipped; two consecutive `_` are
     *       rejected (the second one starts a pass that matches nothing)
     *
     * implicit rules:
     *  - if smaller dimensions are not specified => square or cubic form;
     *  - if output is undefined => compute output;
     *  - if padding is undefined => compute trivial padding;
     */

    // Defaults for fields the string may omit. A negative padding marks it
    // as "unspecified"; it is resolved further below.
    desc_t d {0};
    d.g = 1;
    d.mb = 2;
    d.sd = d.sh = d.sw = 1;
    d.pd = d.ph = d.pw = -1;

    const char *s = str;
    assert(s);

    // CASE_NN(prb, c): if the cursor `s` starts with the literal prefix `prb`:
    //  - consume the prefix and the following base-10 integer (strtol);
    //  - store the value in `d.c` and set `ok`.
    // Seeing the "g" prefix marks the problem as grouped, even when g == 1.
    // A negative value makes the whole parse return FAIL. If the prefix does
    // not match, `s` is left untouched.
#define CASE_NN(prb, c) \
    do { \
        if (!strncmp(prb, s, strlen(prb))) { \
            ok = 1; \
            s += strlen(prb); \
            char *end_s; \
            d.c = strtol(s, &end_s, 10); \
            s += (end_s - s); \
            /* check any # groups, including one, works correctly */ \
            if (!strncmp(prb, "g", 1)) d.has_groups = true; \
            if (d.c < 0) return FAIL; \
            /* printf("@@@debug: %s: %d\n", prb, d. c); */ \
        } \
    } while (0)
#define CASE_N(c) CASE_NN(#c, c)
    // Each pass tries every prefix once, in this fixed order. Fields written
    // back to back in this order (e.g. "g2mb4ic16") are consumed in a single
    // pass. A pass in which no prefix matches and no name is found returns
    // FAIL.
    while (*s) {
        int ok = 0;
        CASE_N(g);
        CASE_N(mb);
        CASE_N(ic);
        CASE_N(id);
        CASE_N(ih);
        CASE_N(iw);
        CASE_N(oc);
        CASE_N(od);
        CASE_N(oh);
        CASE_N(ow);
        CASE_N(kd);
        CASE_N(kh);
        CASE_N(kw);
        CASE_N(sd);
        CASE_N(sh);
        CASE_N(sw);
        CASE_N(pd);
        CASE_N(ph);
        CASE_N(pw);
        CASE_N(dd);
        CASE_N(dh);
        CASE_N(dw);
        // `n` introduces the problem name. Everything after it is taken
        // verbatim and parsing stops.
        if (*s == 'n') {
            d.name = s + 1;
            break;
        }
        // Skip at most one `_` separator per pass.
        if (*s == '_') ++s;
        if (!ok) return FAIL;
    }
#undef CASE_NN
#undef CASE_N

    // Basic sanity checks on the parsed values.
    if (d.has_groups && d.g <= 0) return FAIL;
    if (d.ic == 0 || d.oc == 0) return FAIL;
    if (d.sd <= 0 || d.sh <= 0 || d.sw <= 0) return FAIL;

    // Deconvolution output size for input i, kernel k, stride s, padding p
    // and dilation d. Dilation is stored as the number of skipped elements
    // (0 means dense), hence the (d + 1).
    auto compute_out
            = [](int64_t i, int64_t k, int64_t s, int64_t p, int64_t d) {
                  return (i - 1) * s + (k - 1) * (d + 1) - 2 * p + 1;
              };
    // Inverse of compute_out: the symmetric padding that produces output o.
    auto compute_pad
            = [](int64_t o, int64_t i, int64_t k, int64_t s, int64_t d) {
                  return ((i - 1) * s - o + ((k - 1) * (d + 1) + 1)) / 2;
              };

    // An axis is treated as absent when all of these hold:
    //  - none of its input/kernel/output sizes or its dilation were given;
    //  - its stride is 1;
    //  - its padding is not positive.
    const bool no_d = (d.id | d.kd | d.od | d.dd) == 0 && d.sd == 1 && d.pd < 1;
    const bool no_h = (d.ih | d.kh | d.oh | d.dh) == 0 && d.sh == 1 && d.ph < 1;
    const bool no_w = (d.iw | d.kw | d.ow | d.dw) == 0 && d.sw == 1 && d.pw < 1;

    // For every present axis, input and kernel sizes are mandatory.
    // - If the output size is missing, it is computed, and an unspecified
    //   padding becomes 0.
    // - Otherwise an unspecified padding is derived from the given output.
    if (!no_d) {
        if (!d.id || !d.kd) return FAIL;
        if (!d.od) {
            if (d.pd < 0) d.pd = 0;
            d.od = compute_out(d.id, d.kd, d.sd, d.pd, d.dd);
            if (d.od <= 0) return FAIL;
        } else if (d.pd < 0)
            d.pd = compute_pad(d.od, d.id, d.kd, d.sd, d.dd);
    }

    if (!no_h) {
        if (!d.ih || !d.kh) return FAIL;
        if (!d.oh) {
            if (d.ph < 0) d.ph = 0;
            d.oh = compute_out(d.ih, d.kh, d.sh, d.ph, d.dh);
            if (d.oh <= 0) return FAIL;
        } else if (d.ph < 0)
            d.ph = compute_pad(d.oh, d.ih, d.kh, d.sh, d.dh);
    }

    if (!no_w) {
        if (!d.iw || !d.kw) return FAIL;
        if (!d.ow) {
            if (d.pw < 0) d.pw = 0;
            d.ow = compute_out(d.iw, d.kw, d.sw, d.pw, d.dw);
            if (d.ow <= 0) return FAIL;
        } else if (d.pw < 0)
            d.pw = compute_pad(d.ow, d.iw, d.kw, d.sw, d.dw);
    }

    // NOTE(review): sanitize_desc is defined elsewhere. From this call it
    // presumably does two things for absent axes:
    //  - fills defaults ({1, 1, 1, 1, 0, 0} for o, i, k, s, p, d);
    //  - determines d.ndims.
    // Confirm against its definition.
    if (sanitize_desc(d.ndims, {d.od, d.id, d.kd, d.sd, d.pd, d.dd},
                {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh},
                {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw}, {1, 1, 1, 1, 0, 0}, true)
            != OK)
        return FAIL;

    // Presumably derives the right-side paddings (pd_r/ph_r/pw_r); see
    // desc_t::init_pad_r.
    d.init_pad_r();
    // Publish the result only after every check has passed.
    *desc = d;

    return OK;
}
179 | |
180 | std::ostream &operator<<(std::ostream &s, const desc_t &d) { |
181 | bool print_d = true, print_h = true, print_w = true; |
182 | print_dhw(print_d, print_h, print_w, d.ndims, |
183 | {d.od, d.id, d.kd, d.sd, d.pd, d.dd}, |
184 | {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh}, |
185 | {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw}); |
186 | |
187 | auto print_spatial |
188 | = [&](const char *d_str, int64_t d_val, const char *h_str, |
189 | int64_t h_val, const char *w_str, int64_t w_val) { |
190 | if (print_d) s << d_str << d_val; |
191 | if (print_h) s << h_str << h_val; |
192 | if (print_w) s << w_str << w_val; |
193 | }; |
194 | |
195 | if (canonical || d.has_groups) s << "g" << d.g; |
196 | if (canonical || d.mb != 2) s << "mb" << d.mb; |
197 | s << "ic" << d.ic; |
198 | print_spatial("id" , d.id, "ih" , d.ih, "iw" , d.iw); |
199 | s << "oc" << d.oc; |
200 | print_spatial("od" , d.od, "oh" , d.oh, "ow" , d.ow); |
201 | print_spatial("kd" , d.kd, "kh" , d.kh, "kw" , d.kw); |
202 | |
203 | if (canonical || d.sh != 1 || d.sw != 1 || d.sd != 1) |
204 | print_spatial("sd" , d.sd, "sh" , d.sh, "sw" , d.sw); |
205 | |
206 | print_spatial("pd" , d.pd, "ph" , d.ph, "pw" , d.pw); |
207 | |
208 | if (canonical || d.dh != 0 || d.dw != 0 || d.dd != 0) |
209 | print_spatial("dd" , d.dd, "dh" , d.dh, "dw" , d.dw); |
210 | |
211 | if (!d.name.empty()) s << "n" << d.name; |
212 | |
213 | return s; |
214 | } |
215 | |
216 | dims_t desc_t::src_dims() const { |
217 | dims_t src_dims {mb, ic, id, ih, iw}; |
218 | for (int d = 0; d < 5 - ndims; ++d) { |
219 | src_dims.erase(src_dims.begin() + 2); |
220 | } |
221 | |
222 | return src_dims; |
223 | } |
224 | |
225 | dims_t desc_t::wei_dims() const { |
226 | dims_t wei_dims {g, oc / g, ic / g, kd, kh, kw}; |
227 | if (!has_groups) { wei_dims.erase(wei_dims.begin()); } |
228 | for (int d = 0; d < 5 - ndims; ++d) { |
229 | wei_dims.erase(wei_dims.begin() + 2 + has_groups); |
230 | } |
231 | |
232 | return wei_dims; |
233 | } |
234 | |
235 | dims_t desc_t::bia_dims() const { |
236 | dims_t bia_dims {oc}; |
237 | return bia_dims; |
238 | } |
239 | |
240 | dims_t desc_t::dst_dims() const { |
241 | dims_t dst_dims {mb, oc, od, oh, ow}; |
242 | for (int d = 0; d < 5 - ndims; ++d) { |
243 | dst_dims.erase(dst_dims.begin() + 2); |
244 | } |
245 | |
246 | return dst_dims; |
247 | } |
248 | |
249 | dims_t desc_t::strides() const { |
250 | dims_t strides {sd, sh, sw}; |
251 | return dims_t(strides.begin() + (5 - ndims), strides.end()); |
252 | } |
253 | |
254 | dims_t desc_t::dilations() const { |
255 | dims_t dilations {dd, dh, dw}; |
256 | return dims_t(dilations.begin() + (5 - ndims), dilations.end()); |
257 | } |
258 | |
259 | dims_t desc_t::padding() const { |
260 | dims_t padding {pd, ph, pw}; |
261 | return dims_t(padding.begin() + (5 - ndims), padding.end()); |
262 | } |
263 | |
264 | dims_t desc_t::padding_r() const { |
265 | dims_t padding_r {pd_r, ph_r, pw_r}; |
266 | return dims_t(padding_r.begin() + (5 - ndims), padding_r.end()); |
267 | } |
268 | |
269 | int64_t desc_t::desc_nelems(int arg, int mask) const { |
270 | std::vector<int64_t> src {mb, ic, id, ih, iw}; |
271 | std::vector<int64_t> wei {g, oc, ic, kd, kh, kw}; |
272 | std::vector<int64_t> dst {mb, oc, od, oh, ow}; |
273 | std::vector<int64_t> dummy; |
274 | |
275 | if (!has_groups) wei.erase(wei.begin()); |
276 | |
277 | for (int d = 0; d < 5 - ndims; d++) { |
278 | src.erase(src.begin() + 2); |
279 | wei.erase(wei.begin() + 2); |
280 | dst.erase(dst.begin() + 2); |
281 | } |
282 | |
283 | std::vector<int64_t> &dims = dummy; |
284 | switch (arg) { |
285 | case DNNL_ARG_SRC: dims = src; break; |
286 | case DNNL_ARG_WEIGHTS: dims = wei; break; |
287 | case DNNL_ARG_DST: dims = dst; break; |
288 | default: assert(!"unsupported arg" ); |
289 | } |
290 | |
291 | int64_t nelems = 1; |
292 | for (int d = 0; d < ndims; d++) { |
293 | nelems *= (mask & (1 << d)) ? dims[d] : 1; |
294 | } |
295 | return nelems; |
296 | } |
297 | |
298 | void prb_t::count_ops() { |
299 | if (ops > 0) return; |
300 | |
301 | int64_t od_t = this->id; |
302 | int64_t oh_t = this->ih; |
303 | int64_t ow_t = this->iw; |
304 | int64_t id_t = this->od; |
305 | int64_t ih_t = this->oh; |
306 | int64_t iw_t = this->ow; |
307 | double sp_ops = 0; |
308 | for_(int64_t od = 0; od < od_t; ++od) |
309 | for_(int64_t oh = 0; oh < oh_t; ++oh) |
310 | for (int64_t ow = 0; ow < ow_t; ++ow) { |
311 | for (int64_t kd = 0; kd < this->kd; ++kd) { |
312 | const int64_t id = od * this->sd - this->pd + kd * (this->dd + 1); |
313 | if (id < 0 || id >= id_t) continue; |
314 | for (int64_t kh = 0; kh < this->kh; ++kh) { |
315 | const int64_t ih |
316 | = oh * this->sh - this->ph + kh * (this->dh + 1); |
317 | if (ih < 0 || ih >= ih_t) continue; |
318 | for (int64_t kw = 0; kw < this->kw; ++kw) { |
319 | const int64_t iw |
320 | = ow * this->sw - this->pw + kw * (this->dw + 1); |
321 | if (iw < 0 || iw >= iw_t) continue; |
322 | sp_ops += 1; |
323 | } |
324 | } |
325 | } |
326 | } |
327 | |
328 | ops = 2 * this->mb * this->oc * this->ic / this->g * sp_ops; |
329 | } |
330 | |
331 | float *prb_t::generate_scales(int arg) { |
332 | const auto &scales = attr.scales; |
333 | if (scales.is_def()) return nullptr; |
334 | |
335 | const auto &e = scales.get(arg); |
336 | if (e.policy == policy_t::COMMON) { |
337 | float *s = (float *)zmalloc(sizeof(float), 4); |
338 | SAFE_V(s != nullptr ? OK : FAIL); |
339 | s[0] = e.scale; |
340 | return s; |
341 | } |
342 | |
343 | assert(e.policy == policy_t::PER_OC); |
344 | auto mask = attr_t::get_default_mask(e.policy, arg); |
345 | if (arg == DNNL_ARG_WEIGHTS && has_groups) mask = (1 << mask) + 1; |
346 | int64_t s_nelems = desc_nelems(arg, mask); |
347 | |
348 | float *s = (float *)zmalloc(sizeof(float) * s_nelems, 64); |
349 | SAFE_V(s != nullptr ? OK : FAIL); |
350 | |
351 | const float K = 32; |
352 | /* scale in [1/K .. K], with starting point at e.scale */ |
353 | float s_val[2] = {e.scale, e.scale / 2}; |
354 | for (int64_t i = 0; i < s_nelems; ++i) { |
355 | int64_t si = i % 2; // 0 -> left, 1 -> right |
356 | s[i] = s_val[si]; |
357 | if (si == 0) { |
358 | s_val[si] /= 2.; |
359 | // turn around to become ~K |
360 | if (s_val[si] < 1. / K) s_val[si] *= K * K; |
361 | } else { |
362 | s_val[si] *= 2.; |
363 | // turn around to become ~K |
364 | if (s_val[si] > K) s_val[si] /= K * K; |
365 | } |
366 | } |
367 | return s; |
368 | } |
369 | |
370 | int32_t *prb_t::generate_zero_points(int arg) { |
371 | const auto &zp = attr.zero_points; |
372 | if (zp.is_def(arg)) return nullptr; |
373 | |
374 | const auto &e = zp.get(arg); |
375 | const auto mask = attr_t::get_default_mask(e.policy); |
376 | int64_t zp_nelems = desc_nelems(arg, mask); |
377 | |
378 | int32_t *ptr = (int32_t *)zmalloc(sizeof(int32_t) * zp_nelems, 64); |
379 | SAFE_V(ptr != nullptr ? OK : FAIL); |
380 | |
381 | for (int i = 0; i < zp_nelems; ++i) |
382 | ptr[i] = e.value + i % 3; |
383 | |
384 | return ptr; |
385 | } |
386 | |
387 | std::ostream &operator<<(std::ostream &s, const prb_t &prb) { |
388 | dump_global_params(s); |
389 | settings_t def; |
390 | |
391 | if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " " ; |
392 | if (canonical || prb.cfg != def.cfg[0]) s << "--cfg=" << prb.cfg << " " ; |
393 | if (canonical || prb.stag != def.stag[0]) s << "--stag=" << prb.stag << " " ; |
394 | if (canonical || prb.wtag != def.wtag[0]) s << "--wtag=" << prb.wtag << " " ; |
395 | if (canonical || prb.dtag != def.dtag[0]) s << "--dtag=" << prb.dtag << " " ; |
396 | if (canonical || prb.alg != def.alg[0]) |
397 | s << "--alg=" << alg2str(prb.alg) << " " ; |
398 | |
399 | s << prb.attr; |
400 | if (canonical || prb.ctx_init != def.ctx_init[0]) |
401 | s << "--ctx-init=" << prb.ctx_init << " " ; |
402 | if (canonical || prb.ctx_exe != def.ctx_exe[0]) |
403 | s << "--ctx-exe=" << prb.ctx_exe << " " ; |
404 | |
405 | s << static_cast<const desc_t &>(prb); |
406 | |
407 | return s; |
408 | } |
409 | |
410 | } // namespace deconv |
411 | |