1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <float.h>
18#include <math.h>
19#include <stdio.h>
20#include <stdlib.h>
21#include <string.h>
22
23#include "oneapi/dnnl/dnnl.h"
24
25#include "dnn_types.hpp"
26#include "dnnl_common.hpp"
27
28#include "conv/conv.hpp"
29
30namespace conv {
31
32alg_t str2alg(const char *str) {
33#define CASE(_alg) \
34 if (!strcasecmp(STRINGIFY(_alg), str)) return _alg
35 CASE(AUTO);
36 CASE(convolution_auto);
37 CASE(DIRECT);
38 CASE(convolution_direct);
39 CASE(WINO);
40 CASE(convolution_wino);
41#undef CASE
42 assert(!"unknown algorithm");
43 return UNDEF;
44}
45
46const char *alg2str(alg_t alg) {
47 if (alg == AUTO) return "auto";
48 if (alg == DIRECT) return "direct";
49 if (alg == WINO) return "wino";
50 assert(!"unknown algorithm");
51 return "undef";
52}
53
54alg_t alg_kind2alg(dnnl_alg_kind_t alg) {
55 if (alg == dnnl_convolution_auto) return AUTO;
56 if (alg == dnnl_convolution_direct) return DIRECT;
57 if (alg == dnnl_convolution_winograd) return WINO;
58 assert(!"unknown algorithm");
59 return DIRECT;
60}
61
// Parses a textual convolution problem descriptor `str` into `*desc`.
//
// Returns OK on success and FAIL when the string is malformed or describes an
// inconsistent problem. `*desc` is written only on success; on FAIL it is
// left untouched.
int str2desc(desc_t *desc, const char *str) {
    /* canonical form:
     * gXmbX_icXidXihXiwX_ocXodXohXowX_kdXkhXkwX_sdXshXswX_pdXphXpwX_ddXdhXdwXnS
     *
     * where X is number, S - string
     * note: symbol `_` is ignored
     *
     * implicit rules:
     *  - if smaller dimensions are not specified => square or cubic form;
     *  - if output is undefined => compute output;
     *  - if padding is undefined => compute trivial padding;
     */

    // Defaults: a single group, minibatch of 2, unit strides. Padding starts
    // at -1, which marks it as "not specified" for the logic further below.
    desc_t d {0};
    d.g = 1;
    d.mb = 2;
    d.sd = d.sh = d.sw = 1;
    d.pd = d.ph = d.pw = -1;

    const char *s = str;
    assert(s);

// CASE_NN(prb, c): if the remaining input starts with the prefix `prb`,
// consume the prefix and the decimal number after it, and store the number in
// `d.c`. Seeing the "g" prefix also records that groups were requested
// explicitly. A negative value makes the enclosing function return FAIL.
// The macro relies on the locals `s`, `d` and `ok` of the enclosing scope.
#define CASE_NN(prb, c) \
    do { \
        if (!strncmp(prb, s, strlen(prb))) { \
            ok = 1; \
            s += strlen(prb); \
            char *end_s; \
            d.c = strtol(s, &end_s, 10); \
            s += (end_s - s); \
            /* check any # groups, including one, works correctly */ \
            if (!strncmp(prb, "g", 1)) d.has_groups = true; \
            if (d.c < 0) return FAIL; \
            /* printf("@@@debug: %s: %d\n", prb, d. c); */ \
        } \
    } while (0)
// CASE_N(c): shorthand for a field whose textual prefix equals its name.
#define CASE_N(c) CASE_NN(#c, c)
    // Every pass tries all known prefixes in canonical order, so several
    // consecutive tokens can be consumed within a single pass.
    while (*s) {
        int ok = 0;
        CASE_N(g);
        CASE_N(mb);
        CASE_N(ic);
        CASE_N(id);
        CASE_N(ih);
        CASE_N(iw);
        CASE_N(oc);
        CASE_N(od);
        CASE_N(oh);
        CASE_N(ow);
        CASE_N(kd);
        CASE_N(kh);
        CASE_N(kw);
        CASE_N(sd);
        CASE_N(sh);
        CASE_N(sw);
        CASE_N(pd);
        CASE_N(ph);
        CASE_N(pw);
        CASE_N(dd);
        CASE_N(dh);
        CASE_N(dw);
        // `n` starts the problem name, which takes the rest of the string.
        if (*s == 'n') {
            d.name = s + 1;
            break;
        }
        // At most one separator is skipped per pass.
        if (*s == '_') ++s;
        // No prefix matched in this pass: the input is not a descriptor.
        if (!ok) return FAIL;
    }
#undef CASE_NN
#undef CASE_N

    // Basic sanity: an explicit group count must be positive, channel counts
    // must be non-zero and strides must be positive.
    if (d.has_groups && d.g <= 0) return FAIL;
    if (d.ic == 0 || d.oc == 0) return FAIL;
    if (d.sd <= 0 || d.sh <= 0 || d.sw <= 0) return FAIL;

    // Output size along one axis from input `i`, kernel `k`, stride `s`,
    // padding `p` (applied on both sides) and dilation `d`, where `d == 0`
    // means a dense kernel.
    auto compute_out
            = [](int64_t i, int64_t k, int64_t s, int64_t p, int64_t d) {
                  return (i - ((k - 1) * (d + 1) + 1) + 2 * p) / s + 1;
              };
    // Padding along one axis derived from output `o`, input `i`, kernel `k`,
    // stride `s` and dilation `d`; the total is halved with integer division.
    auto compute_pad
            = [](int64_t o, int64_t i, int64_t k, int64_t s, int64_t d) {
                  return ((o - 1) * s - i + ((k - 1) * (d + 1) + 1)) / 2;
              };

    // An axis counts as "not given" when none of its sizes or dilation were
    // set, the stride kept its default and the padding is not positive.
    const bool no_d = (d.id | d.kd | d.od | d.dd) == 0 && d.sd == 1 && d.pd < 1;
    const bool no_h = (d.ih | d.kh | d.oh | d.dh) == 0 && d.sh == 1 && d.ph < 1;
    const bool no_w = (d.iw | d.kw | d.ow | d.dw) == 0 && d.sw == 1 && d.pw < 1;

    // For every given axis the input and kernel sizes are mandatory. A missing
    // output is computed (unspecified padding becomes 0 first); otherwise an
    // unspecified padding is derived from the given output.
    if (!no_d) {
        if (!d.id || !d.kd) return FAIL;
        if (!d.od) {
            if (d.pd < 0) d.pd = 0;
            d.od = compute_out(d.id, d.kd, d.sd, d.pd, d.dd);
            if (d.od <= 0) return FAIL;
        } else if (d.pd < 0)
            d.pd = compute_pad(d.od, d.id, d.kd, d.sd, d.dd);
    }

    if (!no_h) {
        if (!d.ih || !d.kh) return FAIL;
        if (!d.oh) {
            if (d.ph < 0) d.ph = 0;
            d.oh = compute_out(d.ih, d.kh, d.sh, d.ph, d.dh);
            if (d.oh <= 0) return FAIL;
        } else if (d.ph < 0)
            d.ph = compute_pad(d.oh, d.ih, d.kh, d.sh, d.dh);
    }

    if (!no_w) {
        if (!d.iw || !d.kw) return FAIL;
        if (!d.ow) {
            if (d.pw < 0) d.pw = 0;
            d.ow = compute_out(d.iw, d.kw, d.sw, d.pw, d.dw);
            if (d.ow <= 0) return FAIL;
        } else if (d.pw < 0)
            d.pw = compute_pad(d.ow, d.iw, d.kw, d.sw, d.dw);
    }

    // NOTE(review): sanitize_desc() is defined elsewhere; presumably it fills
    // in the axes that were not given and sets `d.ndims` -- confirm there.
    if (sanitize_desc(d.ndims, {d.od, d.id, d.kd, d.sd, d.pd, d.dd},
                {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh},
                {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw}, {1, 1, 1, 1, 0, 0}, true)
            != OK)
        return FAIL;

    // NOTE(review): presumably derives the right-side paddings (pd_r, ph_r,
    // pw_r) -- confirm against desc_t::init_pad_r().
    d.init_pad_r();
    *desc = d;

    return OK;
}
191
192std::ostream &operator<<(std::ostream &s, const desc_t &d) {
193 bool print_d = true, print_h = true, print_w = true;
194 print_dhw(print_d, print_h, print_w, d.ndims,
195 {d.od, d.id, d.kd, d.sd, d.pd, d.dd},
196 {d.oh, d.ih, d.kh, d.sh, d.ph, d.dh},
197 {d.ow, d.iw, d.kw, d.sw, d.pw, d.dw});
198
199 auto print_spatial
200 = [&](const char *d_str, int64_t d_val, const char *h_str,
201 int64_t h_val, const char *w_str, int64_t w_val) {
202 if (print_d) s << d_str << d_val;
203 if (print_h) s << h_str << h_val;
204 if (print_w) s << w_str << w_val;
205 };
206
207 if (canonical || d.has_groups) s << "g" << d.g;
208 if (canonical || d.mb != 2) s << "mb" << d.mb;
209 s << "ic" << d.ic;
210 print_spatial("id", d.id, "ih", d.ih, "iw", d.iw);
211 s << "oc" << d.oc;
212 print_spatial("od", d.od, "oh", d.oh, "ow", d.ow);
213 print_spatial("kd", d.kd, "kh", d.kh, "kw", d.kw);
214
215 if (canonical || d.sh != 1 || d.sw != 1 || d.sd != 1)
216 print_spatial("sd", d.sd, "sh", d.sh, "sw", d.sw);
217
218 print_spatial("pd", d.pd, "ph", d.ph, "pw", d.pw);
219
220 if (canonical || d.dh != 0 || d.dw != 0 || d.dd != 0)
221 print_spatial("dd", d.dd, "dh", d.dh, "dw", d.dw);
222
223 if (!d.name.empty()) s << "n" << d.name;
224
225 return s;
226}
227
228dims_t desc_t::src_dims() const {
229 dims_t src_dims {mb, ic, id, ih, iw};
230 for (int d = 0; d < 5 - ndims; ++d) {
231 src_dims.erase(src_dims.begin() + 2);
232 }
233
234 return src_dims;
235}
236
237dims_t desc_t::wei_dims() const {
238 dims_t wei_dims {g, oc / g, ic / g, kd, kh, kw};
239 if (!has_groups) { wei_dims.erase(wei_dims.begin()); }
240 for (int d = 0; d < 5 - ndims; ++d) {
241 wei_dims.erase(wei_dims.begin() + 2 + has_groups);
242 }
243
244 return wei_dims;
245}
246
247dims_t desc_t::bia_dims() const {
248 dims_t bia_dims {oc};
249 return bia_dims;
250}
251
252dims_t desc_t::dst_dims() const {
253 dims_t dst_dims {mb, oc, od, oh, ow};
254 for (int d = 0; d < 5 - ndims; ++d) {
255 dst_dims.erase(dst_dims.begin() + 2);
256 }
257
258 return dst_dims;
259}
260
261dims_t desc_t::strides() const {
262 dims_t strides {sd, sh, sw};
263 return dims_t(strides.begin() + (5 - ndims), strides.end());
264}
265
266dims_t desc_t::dilations() const {
267 dims_t dilations {dd, dh, dw};
268 return dims_t(dilations.begin() + (5 - ndims), dilations.end());
269}
270
271dims_t desc_t::padding() const {
272 dims_t padding {pd, ph, pw};
273 return dims_t(padding.begin() + (5 - ndims), padding.end());
274}
275
276dims_t desc_t::padding_r() const {
277 dims_t padding_r {pd_r, ph_r, pw_r};
278 return dims_t(padding_r.begin() + (5 - ndims), padding_r.end());
279}
280
281int64_t desc_t::desc_nelems(int arg, int mask) const {
282 dims_t dims;
283 switch (arg) {
284 case DNNL_ARG_SRC: dims = src_dims(); break;
285 case DNNL_ARG_WEIGHTS: dims = wei_dims(); break;
286 case DNNL_ARG_DST: dims = dst_dims(); break;
287 default: assert(!"unsupported arg");
288 }
289
290 int64_t nelems = 1;
291 for (int d = 0; d < ndims; d++) {
292 nelems *= (mask & (1 << d)) ? dims[d] : 1;
293 }
294 return nelems;
295}
296
297void prb_t::count_ops() {
298 if (ops > 0) return;
299
300 double sp_ops = 0;
301 for_(int64_t od = 0; od < this->od; ++od)
302 for_(int64_t oh = 0; oh < this->oh; ++oh)
303 for (int64_t ow = 0; ow < this->ow; ++ow) {
304 for (int64_t kd = 0; kd < this->kd; ++kd) {
305 const int64_t id = od * this->sd - this->pd + kd * (this->dd + 1);
306 if (id < 0 || id >= this->id) continue;
307 for (int64_t kh = 0; kh < this->kh; ++kh) {
308 const int64_t ih
309 = oh * this->sh - this->ph + kh * (this->dh + 1);
310 if (ih < 0 || ih >= this->ih) continue;
311 for (int64_t kw = 0; kw < this->kw; ++kw) {
312 const int64_t iw
313 = ow * this->sw - this->pw + kw * (this->dw + 1);
314 if (iw < 0 || iw >= this->iw) continue;
315 sp_ops += 1;
316 }
317 }
318 }
319 }
320
321 ops = 2 * this->mb * this->oc * this->ic / this->g * sp_ops;
322}
323
324float *prb_t::generate_scales(int arg) const {
325 const auto &scales = attr.scales;
326 if (scales.is_def()) return nullptr;
327
328 const auto &e = scales.get(arg);
329 if (e.policy == policy_t::COMMON) {
330 float *s = (float *)zmalloc(sizeof(float), 4);
331 SAFE_V(s != nullptr ? OK : FAIL);
332 s[0] = e.scale;
333 return s;
334 }
335
336 assert(e.policy == policy_t::PER_OC);
337 auto mask = attr_t::get_default_mask(e.policy, arg);
338 if (arg == DNNL_ARG_WEIGHTS && has_groups) mask = (1 << mask) + 1;
339 int64_t s_nelems = desc_nelems(arg, mask);
340
341 float *s = (float *)zmalloc(sizeof(float) * s_nelems, 64);
342 SAFE_V(s != nullptr ? OK : FAIL);
343
344 const float K = 32;
345 /* scale in [1/K .. K], with starting point at e.scale */
346 float s_val[2] = {e.scale, e.scale / 2};
347 for (int64_t i = 0; i < s_nelems; ++i) {
348 int64_t si = i % 2; // 0 -> left, 1 -> right
349 s[i] = s_val[si];
350 if (si == 0) {
351 s_val[si] /= 2.;
352 // turn around to become ~K
353 if (s_val[si] < 1. / K) s_val[si] *= K * K;
354 } else {
355 s_val[si] *= 2.;
356 // turn around to become ~K
357 if (s_val[si] > K) s_val[si] /= K * K;
358 }
359 }
360 return s;
361}
362
363int32_t *prb_t::generate_zero_points(int arg) const {
364 const auto &zp = attr.zero_points;
365 if (zp.is_def(arg)) return nullptr;
366
367 const auto &e = zp.get(arg);
368 const auto mask = attr_t::get_default_mask(e.policy);
369 int64_t zp_nelems = desc_nelems(arg, mask);
370
371 int32_t *ptr = (int32_t *)zmalloc(sizeof(int32_t) * zp_nelems, 64);
372 SAFE_V(ptr != nullptr ? OK : FAIL);
373
374 for (int i = 0; i < zp_nelems; ++i)
375 ptr[i] = e.value + i % 3;
376
377 return ptr;
378}
379
380std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
381 dump_global_params(s);
382 settings_t def;
383
384 if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " ";
385 if (canonical || prb.cfg != def.cfg[0]) s << "--cfg=" << prb.cfg << " ";
386 if (canonical || prb.stag != def.stag[0]) s << "--stag=" << prb.stag << " ";
387 if (canonical || prb.wtag != def.wtag[0]) s << "--wtag=" << prb.wtag << " ";
388 if (canonical || prb.dtag != def.dtag[0]) s << "--dtag=" << prb.dtag << " ";
389 if (canonical || prb.alg != def.alg[0])
390 s << "--alg=" << alg2str(prb.alg) << " ";
391
392 s << prb.attr;
393 if (canonical || prb.ctx_init != def.ctx_init[0])
394 s << "--ctx-init=" << prb.ctx_init << " ";
395 if (canonical || prb.ctx_exe != def.ctx_exe[0])
396 s << "--ctx-exe=" << prb.ctx_exe << " ";
397
398 s << static_cast<const desc_t &>(prb);
399
400 return s;
401}
402
403} // namespace conv
404