1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <stdlib.h> |
18 | #include <string.h> |
19 | |
20 | #include "common.hpp" |
21 | #include "dnnl_common.hpp" |
22 | #include "dnnl_memory.hpp" |
23 | #include "utils/parser.hpp" |
24 | |
25 | #include "self/self.hpp" |
26 | |
27 | using namespace parser; |
28 | |
29 | namespace self { |
30 | |
31 | using pk_t = attr_t::post_ops_t::kind_t; |
32 | |
33 | static int check_simple_enums() { |
34 | /* attr::post_ops::kind */ |
35 | using p = attr_t::post_ops_t; |
36 | SELF_CHECK_CASE_STR_EQ(p::kind2str(p::kind_t::SUM), "sum" ); |
37 | SELF_CHECK_CASE_STR_EQ(p::kind2str(p::kind_t::RELU), "relu" ); |
38 | |
39 | SELF_CHECK_EQ(p::str2kind("sum" ), p::kind_t::SUM); |
40 | SELF_CHECK_EQ(p::str2kind("SuM" ), p::kind_t::SUM); |
41 | |
42 | SELF_CHECK_EQ(p::str2kind("relu" ), p::kind_t::RELU); |
43 | SELF_CHECK_EQ(p::str2kind("ReLU" ), p::kind_t::RELU); |
44 | |
45 | return OK; |
46 | } |
47 | |
48 | static int check_attr2str() { |
49 | attr_t attr; |
50 | SELF_CHECK_EQ(attr.is_def(), true); |
51 | |
52 | SELF_CHECK_PRINT_EQ(attr, "" ); |
53 | |
54 | attr = attr_t(); |
55 | attr.zero_points.set(DNNL_ARG_SRC, policy_t::COMMON, 1, false); |
56 | SELF_CHECK_PRINT_EQ(attr, "--attr-zero-points=src:common:1 " ); |
57 | |
58 | attr.zero_points.set(DNNL_ARG_SRC, policy_t::PER_DIM_0, 3, false); |
59 | attr.zero_points.set(DNNL_ARG_WEIGHTS, {policy_t::PER_DIM_1, 2, true}); |
60 | SELF_CHECK_PRINT_EQ2(attr, |
61 | "--attr-zero-points=src:per_dim_0:3+wei:per_dim_1:2* " , |
62 | "--attr-zero-points=wei:per_dim_1:2*+src:per_dim_0:3 " ); |
63 | |
64 | attr = attr_t(); |
65 | attr.scales.set(DNNL_ARG_SRC_0, attr_t::scale_t(policy_t::COMMON, 2.3)); |
66 | SELF_CHECK_PRINT_EQ(attr, "--attr-scales=src:common:2.3 " ); |
67 | |
68 | attr = attr_t(); |
69 | attr.scales.set( |
70 | DNNL_ARG_SRC_0, attr_t::scale_t(policy_t::COMMON, 2.3, true)); |
71 | SELF_CHECK_PRINT_EQ(attr, "--attr-scales=src:common:2.3* " ); |
72 | |
73 | attr.scales.set(DNNL_ARG_SRC_0, attr_t::scale_t(policy_t::COMMON, 2.2)); |
74 | attr.scales.set(DNNL_ARG_SRC_1, attr_t::scale_t(policy_t::COMMON, 3)); |
75 | SELF_CHECK_PRINT_EQ(attr, "--attr-scales=src:common:2.2+src1:common:3 " ); |
76 | |
77 | return OK; |
78 | } |
79 | |
80 | static int check_attr() { |
81 | #define SELF_CHECK_OSCALE(os, os_policy, os_scale, os_runtime) \ |
82 | do { \ |
83 | SELF_CHECK_EQ((os).policy, policy_t::os_policy); \ |
84 | SELF_CHECK_EQ((os).scale, os_scale); \ |
85 | SELF_CHECK_EQ((os).runtime, os_runtime); \ |
86 | } while (0) |
87 | |
88 | #define SELF_CHECK_ATTR_ZP(zp, arg, zero_points_value, zero_points_runtime) \ |
89 | do { \ |
90 | const auto &entry = (zp).get(arg); \ |
91 | SELF_CHECK_EQ(entry.value, zero_points_value); \ |
92 | SELF_CHECK_EQ(entry.runtime, zero_points_runtime); \ |
93 | } while (0) |
94 | |
95 | { |
96 | std::vector<attr_t::zero_points_t> zp; |
97 | SELF_CHECK_EQ(parse_attr_zero_points(zp, |
98 | "--attr-zero-points=src:common:0+dst:common:-2*" ), |
99 | true); |
100 | SELF_CHECK_EQ(zp.size(), 1); |
101 | SELF_CHECK_ATTR_ZP(zp[0], DNNL_ARG_SRC, 0, false); |
102 | SELF_CHECK_ATTR_ZP(zp[0], DNNL_ARG_WEIGHTS, 0, false); |
103 | SELF_CHECK_ATTR_ZP(zp[0], DNNL_ARG_DST, -2, true); |
104 | } |
105 | |
106 | { |
107 | std::vector<attr_t::arg_scales_t> sc; |
108 | SELF_CHECK_EQ( |
109 | parse_attr_scales(sc, "--attr-scales=src1:common:1.5" ), true); |
110 | SELF_CHECK_EQ(sc.size(), 1); |
111 | SELF_CHECK_EQ(sc[0].get(DNNL_ARG_SRC_1).policy, policy_t::COMMON); |
112 | SELF_CHECK_EQ(sc[0].get(DNNL_ARG_SRC_1).scale, 1.5); |
113 | } |
114 | |
115 | { |
116 | std::vector<attr_t::arg_scales_t> sc; |
117 | SELF_CHECK_EQ(parse_attr_scales(sc, |
118 | "--attr-scales=src:common:2.5+src1:common:1.5" ), |
119 | true); |
120 | SELF_CHECK_EQ(sc.size(), 1); |
121 | SELF_CHECK_EQ(sc[0].get(DNNL_ARG_SRC_0).policy, policy_t::COMMON); |
122 | SELF_CHECK_EQ(sc[0].get(DNNL_ARG_SRC_0).scale, 2.5); |
123 | SELF_CHECK_EQ(sc[0].get(DNNL_ARG_SRC_1).policy, policy_t::COMMON); |
124 | SELF_CHECK_EQ(sc[0].get(DNNL_ARG_SRC_1).scale, 1.5); |
125 | } |
126 | |
127 | // depthwise conv section |
128 | { |
129 | std::vector<attr_t::post_ops_t> po; |
130 | auto st = parse_attr_post_ops(po, "--attr-post-ops=dw_k3s1p1" ); |
131 | SELF_CHECK_EQ(st, true); |
132 | SELF_CHECK_EQ(po[0].len(), 1); |
133 | const auto &e = po[0].entry[0]; |
134 | SELF_CHECK_EQ(e.kind, pk_t::DW_K3S1P1); |
135 | const auto &ce = e.convolution; |
136 | SELF_CHECK_EQ(ce.stride, 1); |
137 | SELF_CHECK_EQ(ce.dst_dt, dnnl_f32); |
138 | SELF_CHECK_OSCALE(ce.wei_scale, COMMON, 1.f, false); |
139 | SELF_CHECK_OSCALE(ce.dst_scale, COMMON, 1.f, false); |
140 | } |
141 | |
142 | { |
143 | std::vector<attr_t::post_ops_t> po; |
144 | auto st = parse_attr_post_ops(po, |
145 | "--attr-post-ops=relu:0.5+dw_k3s2p1:s8:per_oc:2*+linear:2:1" ); |
146 | SELF_CHECK_EQ(st, true); |
147 | SELF_CHECK_EQ(po[0].len(), 3); |
148 | auto &e = po[0].entry[0]; |
149 | SELF_CHECK_EQ(e.kind, pk_t::RELU); |
150 | auto &ee = e.eltwise; |
151 | SELF_CHECK_EQ(ee.alg, dnnl_eltwise_relu); |
152 | SELF_CHECK_EQ(ee.alpha, 0.5f); |
153 | SELF_CHECK_EQ(ee.beta, 0.f); |
154 | |
155 | e = po[0].entry[1]; |
156 | SELF_CHECK_EQ(e.kind, pk_t::DW_K3S2P1); |
157 | const auto &ce = e.convolution; |
158 | SELF_CHECK_EQ(ce.stride, 2); |
159 | SELF_CHECK_EQ(ce.dst_dt, dnnl_s8); |
160 | SELF_CHECK_OSCALE(ce.wei_scale, PER_OC, 2.f, true); |
161 | SELF_CHECK_OSCALE(ce.dst_scale, COMMON, 1.f, false); |
162 | |
163 | e = po[0].entry[2]; |
164 | SELF_CHECK_EQ(e.kind, pk_t::LINEAR); |
165 | ee = e.eltwise; |
166 | SELF_CHECK_EQ(ee.alg, dnnl_eltwise_linear); |
167 | SELF_CHECK_EQ(ee.alpha, 2.f); |
168 | SELF_CHECK_EQ(ee.beta, 1.f); |
169 | } |
170 | |
171 | { |
172 | std::vector<attr_t::post_ops_t> po; |
173 | auto st = parse_attr_post_ops( |
174 | po, "--attr-post-ops=dw_k3s1p1:s8:per_oc:2*:common:4*" ); |
175 | SELF_CHECK_EQ(st, true); |
176 | SELF_CHECK_EQ(po[0].len(), 1); |
177 | const auto &e = po[0].entry[0]; |
178 | SELF_CHECK_EQ(e.kind, pk_t::DW_K3S1P1); |
179 | const auto &ce = e.convolution; |
180 | SELF_CHECK_EQ(ce.stride, 1); |
181 | SELF_CHECK_EQ(ce.dst_dt, dnnl_s8); |
182 | SELF_CHECK_OSCALE(ce.wei_scale, PER_OC, 2.f, true); |
183 | SELF_CHECK_OSCALE(ce.dst_scale, COMMON, 4.f, true); |
184 | } |
185 | |
186 | #undef SELF_CHECK_OSCALE |
187 | #undef SELF_CHECK_ATTR_OSCALE |
188 | #undef SELF_CHECK_ATTR_ZP |
189 | |
190 | return OK; |
191 | } |
192 | |
193 | void append_sum(attr_t::post_ops_t &po, float ascale = 1.f, |
194 | int32_t zero_point = 0, dnnl_data_type_t adt = dnnl_data_type_undef) { |
195 | attr_t::post_ops_t::entry_t e(pk_t::SUM); |
196 | e.sum.scale = ascale; |
197 | e.sum.zero_point = zero_point; |
198 | e.sum.dt = adt; |
199 | po.entry.push_back(e); |
200 | } |
201 | |
202 | void append_convolution(attr_t::post_ops_t &po, pk_t akind, |
203 | dnnl_data_type_t adst_dt = dnnl_f32, |
204 | policy_t apolicy = policy_t::COMMON, float ascale = 1.f) { |
205 | attr_t::post_ops_t::entry_t e(akind); |
206 | e.convolution.stride = e.kind == pk_t::DW_K3S1P1 ? 1 : 2; |
207 | e.convolution.dst_dt = adst_dt; |
208 | e.convolution.wei_scale = attr_t::scale_t(apolicy, ascale); |
209 | po.entry.push_back(e); |
210 | } |
211 | |
212 | void append_eltwise(attr_t::post_ops_t &po, pk_t akind, float aalpha = 0.f, |
213 | float abeta = 0.f) { |
214 | attr_t::post_ops_t::entry_t e(akind); |
215 | e.eltwise.alg = attr_t::post_ops_t::kind2dnnl_kind(akind); |
216 | e.eltwise.alpha = aalpha; |
217 | e.eltwise.beta = abeta; |
218 | po.entry.push_back(e); |
219 | } |
220 | |
221 | static int check_post_ops2str() { |
222 | attr_t::post_ops_t po; |
223 | SELF_CHECK_EQ(po.is_def(), true); |
224 | SELF_CHECK_PRINT_EQ(po, "" ); |
225 | |
226 | append_sum(po); |
227 | SELF_CHECK_EQ(po.len(), 1); |
228 | SELF_CHECK_PRINT_EQ(po, "sum" ); |
229 | |
230 | append_eltwise(po, pk_t::RELU); |
231 | SELF_CHECK_EQ(po.len(), 2); |
232 | SELF_CHECK_PRINT_EQ(po, "sum+relu" ); |
233 | |
234 | append_sum(po, 2.f, 1, dnnl_s8); |
235 | SELF_CHECK_EQ(po.len(), 3); |
236 | SELF_CHECK_PRINT_EQ(po, "sum+relu+sum:2:1:s8" ); |
237 | |
238 | append_eltwise(po, pk_t::LINEAR, 5.f, 10.f); |
239 | SELF_CHECK_EQ(po.len(), 4); |
240 | SELF_CHECK_PRINT_EQ(po, "sum+relu+sum:2:1:s8+linear:5:10" ); |
241 | |
242 | append_convolution(po, pk_t::DW_K3S1P1); |
243 | SELF_CHECK_EQ(po.len(), 5); |
244 | SELF_CHECK_PRINT_EQ(po, "sum+relu+sum:2:1:s8+linear:5:10+dw_k3s1p1" ); |
245 | |
246 | append_convolution(po, pk_t::DW_K3S2P1, dnnl_s32, policy_t::PER_OC, 2.f); |
247 | SELF_CHECK_EQ(po.len(), 6); |
248 | SELF_CHECK_PRINT_EQ(po, |
249 | "sum+relu+sum:2:1:s8+linear:5:10+dw_k3s1p1+dw_k3s2p1:s32:per_oc:" |
250 | "2" ); |
251 | |
252 | return OK; |
253 | } |
254 | |
255 | static int check_str2post_ops() { |
256 | attr_t::post_ops_t ops; |
257 | |
258 | SELF_CHECK_EQ(ops.is_def(), true); |
259 | |
260 | auto quick = [&](int len) -> int { |
261 | for (int i = 0; i < 2; ++i) { |
262 | if (2 * i + 0 >= len) return OK; |
263 | SELF_CHECK_EQ(ops.entry[2 * i + 0].kind, attr_t::post_ops_t::SUM); |
264 | SELF_CHECK_EQ(ops.entry[2 * i + 0].sum.scale, 2. + i); |
265 | if (2 * i + 1 >= len) return OK; |
266 | SELF_CHECK_EQ(ops.entry[2 * i + 1].kind, attr_t::post_ops_t::RELU); |
267 | SELF_CHECK_EQ(ops.entry[2 * i + 1].eltwise.alpha, 0.); |
268 | SELF_CHECK_EQ(ops.entry[2 * i + 1].eltwise.beta, 0.); |
269 | } |
270 | return OK; |
271 | }; |
272 | |
273 | ops.from_str("" ); |
274 | SELF_CHECK_EQ(ops.is_def(), true); |
275 | |
276 | ops.from_str("sum:2" ); |
277 | SELF_CHECK_EQ(quick(1), OK); |
278 | |
279 | ops.from_str("sum:2+relu" ); |
280 | SELF_CHECK_EQ(quick(2), OK); |
281 | |
282 | ops.from_str("sum:2+relu+sum:3" ); |
283 | SELF_CHECK_EQ(quick(3), OK); |
284 | |
285 | ops.from_str("sum:2+relu+sum:3+relu" ); |
286 | SELF_CHECK_EQ(quick(4), OK); |
287 | |
288 | return OK; |
289 | } |
290 | |
291 | static int check_tags() { |
292 | for (int tag_ = dnnl_format_tag_undef; tag_ != dnnl_format_tag_last; |
293 | tag_++) { |
294 | dnnl_format_tag_t format_tag = (dnnl_format_tag_t)tag_; |
295 | const char *str_tag = fmt_tag2str(format_tag); |
296 | int ndims = 1; |
297 | for (char c = (char)('a' + DNNL_MAX_NDIMS - 1); c >= 'a'; c--) { |
298 | if (strchr(str_tag, c)) { |
299 | ndims = c - 'a' + 1; |
300 | break; |
301 | } |
302 | } |
303 | const dnnl_dims_t dims |
304 | = {7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47}; |
305 | auto md_from_str = dnn_mem_t::init_md(ndims, dims, dnnl_f32, str_tag); |
306 | |
307 | dnnl_memory_desc_t md_from_tag; |
308 | DNN_SAFE(dnnl_memory_desc_create_with_tag( |
309 | &md_from_tag, ndims, dims, dnnl_f32, format_tag), |
310 | CRIT); |
311 | int eq = dnnl_memory_desc_equal(md_from_tag, md_from_str); |
312 | SELF_CHECK_EQ(eq, 1); |
313 | |
314 | DNN_SAFE(dnnl_memory_desc_destroy(md_from_tag), CRIT); |
315 | } |
316 | |
317 | return OK; |
318 | } |
319 | |
320 | static int check_trim_tags() { |
321 | { |
322 | std::string tag = "BA16a16b4a" ; |
323 | std::string ndims_trimmed_tag = trim_tag(tag, 1); |
324 | SELF_CHECK_EQ(true, ndims_trimmed_tag == "A16a4a" ); |
325 | } |
326 | { |
327 | std::string tag = "BA16a16b4a" ; |
328 | std::string masked_trimmed_tag = trim_tag_by_mask(tag, 2); |
329 | SELF_CHECK_EQ(true, masked_trimmed_tag == "A16a" ); |
330 | } |
331 | { |
332 | std::string tag = "abcd" ; |
333 | std::string ndims_trimmed_tag = trim_tag(tag, 2); |
334 | SELF_CHECK_EQ(true, ndims_trimmed_tag == "ab" ); |
335 | } |
336 | { |
337 | std::string tag = "abcd" ; |
338 | std::string masked_trimmed_tag = trim_tag_by_mask(tag, 10); |
339 | SELF_CHECK_EQ(true, masked_trimmed_tag == "ab" ); |
340 | } |
341 | { |
342 | std::string tag = "abcd" ; |
343 | std::string masked_trimmed_tag = trim_tag_by_mask(tag, 12); |
344 | SELF_CHECK_EQ(true, masked_trimmed_tag == "ab" ); |
345 | } |
346 | { |
347 | std::string tag = "aBcd16b" ; |
348 | std::string ndims_trimmed_tag = trim_tag(tag, 2); |
349 | SELF_CHECK_EQ(true, ndims_trimmed_tag == "aB16b" ); |
350 | } |
351 | { |
352 | std::string tag = "aBcd16b" ; |
353 | std::string masked_trimmed_tag = trim_tag_by_mask(tag, 2); |
354 | SELF_CHECK_EQ(true, masked_trimmed_tag == "A16a" ); |
355 | } |
356 | { |
357 | std::string tag = "BA16a16b4a" ; |
358 | std::string masked_trimmed_tag = trim_tag_by_mask(tag, 1); |
359 | SELF_CHECK_EQ(true, masked_trimmed_tag == "A16a4a" ); |
360 | } |
361 | { |
362 | std::string tag = "BADC2c16a8d16b4a" ; |
363 | std::string masked_trimmed_tag = trim_tag_by_mask(tag, 14); |
364 | SELF_CHECK_EQ(true, masked_trimmed_tag == "ACB2b8c16a" ); |
365 | } |
366 | |
367 | return OK; |
368 | } |
369 | |
370 | static int check_skip_impl() { |
371 | skip_impl = "gemm" ; |
372 | SELF_CHECK_EQ(true, maybe_skip("x64:gemm:jit" )); |
373 | |
374 | skip_impl = "ref,x64:gemm" ; |
375 | SELF_CHECK_EQ(true, maybe_skip("x64:gemm:jit" )); |
376 | |
377 | skip_impl = "this,finds,nothing" ; |
378 | SELF_CHECK_EQ(false, maybe_skip("x64:gemm:jit" )); |
379 | |
380 | return OK; |
381 | } |
382 | |
383 | void common() { |
384 | RUN(check_simple_enums()); |
385 | RUN(check_attr2str()); |
386 | RUN(check_attr()); |
387 | RUN(check_post_ops2str()); |
388 | RUN(check_str2post_ops()); |
389 | RUN(check_tags()); |
390 | RUN(check_trim_tags()); |
391 | RUN(check_skip_impl()); |
392 | } |
393 | |
394 | } // namespace self |
395 | |