1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <stdio.h> |
18 | #include <stdlib.h> |
19 | #include <string.h> |
20 | |
21 | #include "oneapi/dnnl/dnnl.h" |
22 | |
23 | #include "dnnl_common.hpp" |
24 | #include "dnnl_debug.hpp" |
25 | |
26 | #include "ip/ip.hpp" |
27 | |
28 | namespace ip { |
29 | |
30 | float *prb_t::generate_scales(int arg) const { |
31 | const auto &scales = attr.scales; |
32 | if (scales.is_def()) return nullptr; |
33 | |
34 | const auto &e = scales.get(arg); |
35 | if (e.policy == policy_t::COMMON) { |
36 | float *s = (float *)zmalloc(sizeof(float), 4); |
37 | SAFE_V(s != nullptr ? OK : FAIL); |
38 | s[0] = e.scale; |
39 | return s; |
40 | } |
41 | |
42 | assert(e.policy == policy_t::PER_OC); |
43 | const auto mask = attr_t::get_default_mask(e.policy, arg); |
44 | int64_t s_nelems = desc_nelems(arg, mask); |
45 | |
46 | float *s = (float *)zmalloc(sizeof(float) * s_nelems, 64); |
47 | SAFE_V(s != nullptr ? OK : FAIL); |
48 | |
49 | const float K = 32; |
50 | /* scale in [1/K .. K], with starting point at e.scale */ |
51 | float s_val[2] = {e.scale, e.scale / 2}; |
52 | for (int64_t i = 0; i < oc; ++i) { |
53 | int64_t si = i % 2; // 0 -> left, 1 -> right |
54 | s[i] = s_val[si]; |
55 | if (si == 0) { |
56 | s_val[si] /= 2.; |
57 | // turn around to become ~K |
58 | if (s_val[si] < 1. / K) s_val[si] *= K * K; |
59 | } else { |
60 | s_val[si] *= 2.; |
61 | // turn around to become ~K |
62 | if (s_val[si] > K) s_val[si] /= K * K; |
63 | } |
64 | } |
65 | return s; |
66 | } |
67 | |
68 | int str2desc(desc_t *desc, const char *str) { |
69 | // Canonical form: mbXicXidXihXiwXocXnS, |
70 | // where |
71 | // X is integer |
72 | // S is string |
73 | // note: symbol `_` is ignored. |
74 | // Cubic/square shapes are supported by specifying just highest dimension. |
75 | |
76 | desc_t d {0}; |
77 | d.mb = 2; |
78 | |
79 | const char *s = str; |
80 | assert(s); |
81 | |
82 | #define CASE_NN(prb, c) \ |
83 | do { \ |
84 | if (!strncmp(prb, s, strlen(prb))) { \ |
85 | ok = 1; \ |
86 | s += strlen(prb); \ |
87 | char *end_s; \ |
88 | d.c = strtol(s, &end_s, 10); \ |
89 | s += (end_s - s); \ |
90 | if (d.c < 0) return FAIL; \ |
91 | /* printf("@@@debug: %s: %d\n", prb, d. c); */ \ |
92 | } \ |
93 | } while (0) |
94 | #define CASE_N(c) CASE_NN(#c, c) |
95 | while (*s) { |
96 | int ok = 0; |
97 | CASE_N(mb); |
98 | CASE_N(ic); |
99 | CASE_N(ih); |
100 | CASE_N(iw); |
101 | CASE_N(id); |
102 | CASE_N(oc); |
103 | if (*s == 'n') { |
104 | d.name = s + 1; |
105 | break; |
106 | } |
107 | if (*s == '_') ++s; |
108 | if (!ok) return FAIL; |
109 | } |
110 | #undef CASE_NN |
111 | #undef CASE_N |
112 | |
113 | if (d.ic == 0 || d.oc == 0) return FAIL; |
114 | |
115 | if (sanitize_desc(d.ndims, {d.id}, {d.ih}, {d.iw}, {1}) != OK) return FAIL; |
116 | |
117 | *desc = d; |
118 | |
119 | return OK; |
120 | } |
121 | |
122 | std::ostream &operator<<(std::ostream &s, const desc_t &d) { |
123 | bool print_d = true, print_h = true, print_w = true; |
124 | print_dhw(print_d, print_h, print_w, d.ndims, {d.id}, {d.ih}, {d.iw}); |
125 | |
126 | if (canonical || d.mb != 2) s << "mb" << d.mb; |
127 | |
128 | s << "ic" << d.ic; |
129 | |
130 | if (print_d) s << "id" << d.id; |
131 | if (print_h) s << "ih" << d.ih; |
132 | if (print_w) s << "iw" << d.iw; |
133 | |
134 | s << "oc" << d.oc; |
135 | |
136 | if (!d.name.empty()) s << "n" << d.name; |
137 | |
138 | return s; |
139 | } |
140 | |
141 | std::ostream &operator<<(std::ostream &s, const prb_t &prb) { |
142 | dump_global_params(s); |
143 | settings_t def; |
144 | |
145 | if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " " ; |
146 | if (canonical || prb.cfg != def.cfg[0]) s << "--cfg=" << prb.cfg << " " ; |
147 | if (canonical || prb.stag != def.stag[0]) s << "--stag=" << prb.stag << " " ; |
148 | if (canonical || prb.wtag != def.wtag[0]) s << "--wtag=" << prb.wtag << " " ; |
149 | if (canonical || prb.dtag != def.dtag[0]) s << "--dtag=" << prb.dtag << " " ; |
150 | |
151 | s << prb.attr; |
152 | if (canonical || prb.ctx_init != def.ctx_init[0]) |
153 | s << "--ctx-init=" << prb.ctx_init << " " ; |
154 | if (canonical || prb.ctx_exe != def.ctx_exe[0]) |
155 | s << "--ctx-exe=" << prb.ctx_exe << " " ; |
156 | |
157 | s << static_cast<const desc_t &>(prb); |
158 | |
159 | return s; |
160 | } |
161 | |
162 | dims_t desc_t::src_dims() const { |
163 | dims_t src_dims {mb, ic, id, ih, iw}; |
164 | for (int d = 0; d < 5 - ndims; ++d) { |
165 | src_dims.erase(src_dims.begin() + 2); |
166 | } |
167 | |
168 | return src_dims; |
169 | } |
170 | |
171 | dims_t desc_t::wei_dims() const { |
172 | dims_t wei_dims {oc, ic, id, ih, iw}; |
173 | for (int d = 0; d < 5 - ndims; ++d) { |
174 | wei_dims.erase(wei_dims.begin() + 2); |
175 | } |
176 | |
177 | return wei_dims; |
178 | } |
179 | |
180 | dims_t desc_t::bia_dims() const { |
181 | dims_t bia_dims {oc}; |
182 | return bia_dims; |
183 | } |
184 | |
185 | dims_t desc_t::dst_dims() const { |
186 | dims_t dst_dims {mb, oc}; |
187 | return dst_dims; |
188 | } |
189 | |
190 | int64_t desc_t::desc_nelems(int arg, int mask) const { |
191 | dims_t dims; |
192 | switch (arg) { |
193 | case DNNL_ARG_SRC: dims = src_dims(); break; |
194 | case DNNL_ARG_WEIGHTS: dims = wei_dims(); break; |
195 | case DNNL_ARG_DST: dims = dst_dims(); break; |
196 | default: assert(!"unsupported arg" ); |
197 | } |
198 | |
199 | int64_t nelems = 1; |
200 | for (int d = 0; d < ndims; d++) { |
201 | nelems *= (mask & (1 << d)) ? dims[d] : 1; |
202 | } |
203 | return nelems; |
204 | } |
205 | |
206 | } // namespace ip |
207 | |