1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <stdlib.h>
18#include <string.h>
19
20#include "common.hpp"
21#include "dnnl_common.hpp"
22#include "dnnl_memory.hpp"
23#include "utils/parser.hpp"
24
25#include "self/self.hpp"
26
27using namespace parser;
28
29namespace self {
30
31using pk_t = attr_t::post_ops_t::kind_t;
32
33static int check_simple_enums() {
34 /* attr::post_ops::kind */
35 using p = attr_t::post_ops_t;
36 SELF_CHECK_CASE_STR_EQ(p::kind2str(p::kind_t::SUM), "sum");
37 SELF_CHECK_CASE_STR_EQ(p::kind2str(p::kind_t::RELU), "relu");
38
39 SELF_CHECK_EQ(p::str2kind("sum"), p::kind_t::SUM);
40 SELF_CHECK_EQ(p::str2kind("SuM"), p::kind_t::SUM);
41
42 SELF_CHECK_EQ(p::str2kind("relu"), p::kind_t::RELU);
43 SELF_CHECK_EQ(p::str2kind("ReLU"), p::kind_t::RELU);
44
45 return OK;
46}
47
48static int check_attr2str() {
49 attr_t attr;
50 SELF_CHECK_EQ(attr.is_def(), true);
51
52 SELF_CHECK_PRINT_EQ(attr, "");
53
54 attr = attr_t();
55 attr.zero_points.set(DNNL_ARG_SRC, policy_t::COMMON, 1, false);
56 SELF_CHECK_PRINT_EQ(attr, "--attr-zero-points=src:common:1 ");
57
58 attr.zero_points.set(DNNL_ARG_SRC, policy_t::PER_DIM_0, 3, false);
59 attr.zero_points.set(DNNL_ARG_WEIGHTS, {policy_t::PER_DIM_1, 2, true});
60 SELF_CHECK_PRINT_EQ2(attr,
61 "--attr-zero-points=src:per_dim_0:3+wei:per_dim_1:2* ",
62 "--attr-zero-points=wei:per_dim_1:2*+src:per_dim_0:3 ");
63
64 attr = attr_t();
65 attr.scales.set(DNNL_ARG_SRC_0, attr_t::scale_t(policy_t::COMMON, 2.3));
66 SELF_CHECK_PRINT_EQ(attr, "--attr-scales=src:common:2.3 ");
67
68 attr = attr_t();
69 attr.scales.set(
70 DNNL_ARG_SRC_0, attr_t::scale_t(policy_t::COMMON, 2.3, true));
71 SELF_CHECK_PRINT_EQ(attr, "--attr-scales=src:common:2.3* ");
72
73 attr.scales.set(DNNL_ARG_SRC_0, attr_t::scale_t(policy_t::COMMON, 2.2));
74 attr.scales.set(DNNL_ARG_SRC_1, attr_t::scale_t(policy_t::COMMON, 3));
75 SELF_CHECK_PRINT_EQ(attr, "--attr-scales=src:common:2.2+src1:common:3 ");
76
77 return OK;
78}
79
80static int check_attr() {
81#define SELF_CHECK_OSCALE(os, os_policy, os_scale, os_runtime) \
82 do { \
83 SELF_CHECK_EQ((os).policy, policy_t::os_policy); \
84 SELF_CHECK_EQ((os).scale, os_scale); \
85 SELF_CHECK_EQ((os).runtime, os_runtime); \
86 } while (0)
87
88#define SELF_CHECK_ATTR_ZP(zp, arg, zero_points_value, zero_points_runtime) \
89 do { \
90 const auto &entry = (zp).get(arg); \
91 SELF_CHECK_EQ(entry.value, zero_points_value); \
92 SELF_CHECK_EQ(entry.runtime, zero_points_runtime); \
93 } while (0)
94
95 {
96 std::vector<attr_t::zero_points_t> zp;
97 SELF_CHECK_EQ(parse_attr_zero_points(zp,
98 "--attr-zero-points=src:common:0+dst:common:-2*"),
99 true);
100 SELF_CHECK_EQ(zp.size(), 1);
101 SELF_CHECK_ATTR_ZP(zp[0], DNNL_ARG_SRC, 0, false);
102 SELF_CHECK_ATTR_ZP(zp[0], DNNL_ARG_WEIGHTS, 0, false);
103 SELF_CHECK_ATTR_ZP(zp[0], DNNL_ARG_DST, -2, true);
104 }
105
106 {
107 std::vector<attr_t::arg_scales_t> sc;
108 SELF_CHECK_EQ(
109 parse_attr_scales(sc, "--attr-scales=src1:common:1.5"), true);
110 SELF_CHECK_EQ(sc.size(), 1);
111 SELF_CHECK_EQ(sc[0].get(DNNL_ARG_SRC_1).policy, policy_t::COMMON);
112 SELF_CHECK_EQ(sc[0].get(DNNL_ARG_SRC_1).scale, 1.5);
113 }
114
115 {
116 std::vector<attr_t::arg_scales_t> sc;
117 SELF_CHECK_EQ(parse_attr_scales(sc,
118 "--attr-scales=src:common:2.5+src1:common:1.5"),
119 true);
120 SELF_CHECK_EQ(sc.size(), 1);
121 SELF_CHECK_EQ(sc[0].get(DNNL_ARG_SRC_0).policy, policy_t::COMMON);
122 SELF_CHECK_EQ(sc[0].get(DNNL_ARG_SRC_0).scale, 2.5);
123 SELF_CHECK_EQ(sc[0].get(DNNL_ARG_SRC_1).policy, policy_t::COMMON);
124 SELF_CHECK_EQ(sc[0].get(DNNL_ARG_SRC_1).scale, 1.5);
125 }
126
127 // depthwise conv section
128 {
129 std::vector<attr_t::post_ops_t> po;
130 auto st = parse_attr_post_ops(po, "--attr-post-ops=dw_k3s1p1");
131 SELF_CHECK_EQ(st, true);
132 SELF_CHECK_EQ(po[0].len(), 1);
133 const auto &e = po[0].entry[0];
134 SELF_CHECK_EQ(e.kind, pk_t::DW_K3S1P1);
135 const auto &ce = e.convolution;
136 SELF_CHECK_EQ(ce.stride, 1);
137 SELF_CHECK_EQ(ce.dst_dt, dnnl_f32);
138 SELF_CHECK_OSCALE(ce.wei_scale, COMMON, 1.f, false);
139 SELF_CHECK_OSCALE(ce.dst_scale, COMMON, 1.f, false);
140 }
141
142 {
143 std::vector<attr_t::post_ops_t> po;
144 auto st = parse_attr_post_ops(po,
145 "--attr-post-ops=relu:0.5+dw_k3s2p1:s8:per_oc:2*+linear:2:1");
146 SELF_CHECK_EQ(st, true);
147 SELF_CHECK_EQ(po[0].len(), 3);
148 auto &e = po[0].entry[0];
149 SELF_CHECK_EQ(e.kind, pk_t::RELU);
150 auto &ee = e.eltwise;
151 SELF_CHECK_EQ(ee.alg, dnnl_eltwise_relu);
152 SELF_CHECK_EQ(ee.alpha, 0.5f);
153 SELF_CHECK_EQ(ee.beta, 0.f);
154
155 e = po[0].entry[1];
156 SELF_CHECK_EQ(e.kind, pk_t::DW_K3S2P1);
157 const auto &ce = e.convolution;
158 SELF_CHECK_EQ(ce.stride, 2);
159 SELF_CHECK_EQ(ce.dst_dt, dnnl_s8);
160 SELF_CHECK_OSCALE(ce.wei_scale, PER_OC, 2.f, true);
161 SELF_CHECK_OSCALE(ce.dst_scale, COMMON, 1.f, false);
162
163 e = po[0].entry[2];
164 SELF_CHECK_EQ(e.kind, pk_t::LINEAR);
165 ee = e.eltwise;
166 SELF_CHECK_EQ(ee.alg, dnnl_eltwise_linear);
167 SELF_CHECK_EQ(ee.alpha, 2.f);
168 SELF_CHECK_EQ(ee.beta, 1.f);
169 }
170
171 {
172 std::vector<attr_t::post_ops_t> po;
173 auto st = parse_attr_post_ops(
174 po, "--attr-post-ops=dw_k3s1p1:s8:per_oc:2*:common:4*");
175 SELF_CHECK_EQ(st, true);
176 SELF_CHECK_EQ(po[0].len(), 1);
177 const auto &e = po[0].entry[0];
178 SELF_CHECK_EQ(e.kind, pk_t::DW_K3S1P1);
179 const auto &ce = e.convolution;
180 SELF_CHECK_EQ(ce.stride, 1);
181 SELF_CHECK_EQ(ce.dst_dt, dnnl_s8);
182 SELF_CHECK_OSCALE(ce.wei_scale, PER_OC, 2.f, true);
183 SELF_CHECK_OSCALE(ce.dst_scale, COMMON, 4.f, true);
184 }
185
186#undef SELF_CHECK_OSCALE
187#undef SELF_CHECK_ATTR_OSCALE
188#undef SELF_CHECK_ATTR_ZP
189
190 return OK;
191}
192
193void append_sum(attr_t::post_ops_t &po, float ascale = 1.f,
194 int32_t zero_point = 0, dnnl_data_type_t adt = dnnl_data_type_undef) {
195 attr_t::post_ops_t::entry_t e(pk_t::SUM);
196 e.sum.scale = ascale;
197 e.sum.zero_point = zero_point;
198 e.sum.dt = adt;
199 po.entry.push_back(e);
200}
201
202void append_convolution(attr_t::post_ops_t &po, pk_t akind,
203 dnnl_data_type_t adst_dt = dnnl_f32,
204 policy_t apolicy = policy_t::COMMON, float ascale = 1.f) {
205 attr_t::post_ops_t::entry_t e(akind);
206 e.convolution.stride = e.kind == pk_t::DW_K3S1P1 ? 1 : 2;
207 e.convolution.dst_dt = adst_dt;
208 e.convolution.wei_scale = attr_t::scale_t(apolicy, ascale);
209 po.entry.push_back(e);
210}
211
212void append_eltwise(attr_t::post_ops_t &po, pk_t akind, float aalpha = 0.f,
213 float abeta = 0.f) {
214 attr_t::post_ops_t::entry_t e(akind);
215 e.eltwise.alg = attr_t::post_ops_t::kind2dnnl_kind(akind);
216 e.eltwise.alpha = aalpha;
217 e.eltwise.beta = abeta;
218 po.entry.push_back(e);
219}
220
221static int check_post_ops2str() {
222 attr_t::post_ops_t po;
223 SELF_CHECK_EQ(po.is_def(), true);
224 SELF_CHECK_PRINT_EQ(po, "");
225
226 append_sum(po);
227 SELF_CHECK_EQ(po.len(), 1);
228 SELF_CHECK_PRINT_EQ(po, "sum");
229
230 append_eltwise(po, pk_t::RELU);
231 SELF_CHECK_EQ(po.len(), 2);
232 SELF_CHECK_PRINT_EQ(po, "sum+relu");
233
234 append_sum(po, 2.f, 1, dnnl_s8);
235 SELF_CHECK_EQ(po.len(), 3);
236 SELF_CHECK_PRINT_EQ(po, "sum+relu+sum:2:1:s8");
237
238 append_eltwise(po, pk_t::LINEAR, 5.f, 10.f);
239 SELF_CHECK_EQ(po.len(), 4);
240 SELF_CHECK_PRINT_EQ(po, "sum+relu+sum:2:1:s8+linear:5:10");
241
242 append_convolution(po, pk_t::DW_K3S1P1);
243 SELF_CHECK_EQ(po.len(), 5);
244 SELF_CHECK_PRINT_EQ(po, "sum+relu+sum:2:1:s8+linear:5:10+dw_k3s1p1");
245
246 append_convolution(po, pk_t::DW_K3S2P1, dnnl_s32, policy_t::PER_OC, 2.f);
247 SELF_CHECK_EQ(po.len(), 6);
248 SELF_CHECK_PRINT_EQ(po,
249 "sum+relu+sum:2:1:s8+linear:5:10+dw_k3s1p1+dw_k3s2p1:s32:per_oc:"
250 "2");
251
252 return OK;
253}
254
255static int check_str2post_ops() {
256 attr_t::post_ops_t ops;
257
258 SELF_CHECK_EQ(ops.is_def(), true);
259
260 auto quick = [&](int len) -> int {
261 for (int i = 0; i < 2; ++i) {
262 if (2 * i + 0 >= len) return OK;
263 SELF_CHECK_EQ(ops.entry[2 * i + 0].kind, attr_t::post_ops_t::SUM);
264 SELF_CHECK_EQ(ops.entry[2 * i + 0].sum.scale, 2. + i);
265 if (2 * i + 1 >= len) return OK;
266 SELF_CHECK_EQ(ops.entry[2 * i + 1].kind, attr_t::post_ops_t::RELU);
267 SELF_CHECK_EQ(ops.entry[2 * i + 1].eltwise.alpha, 0.);
268 SELF_CHECK_EQ(ops.entry[2 * i + 1].eltwise.beta, 0.);
269 }
270 return OK;
271 };
272
273 ops.from_str("");
274 SELF_CHECK_EQ(ops.is_def(), true);
275
276 ops.from_str("sum:2");
277 SELF_CHECK_EQ(quick(1), OK);
278
279 ops.from_str("sum:2+relu");
280 SELF_CHECK_EQ(quick(2), OK);
281
282 ops.from_str("sum:2+relu+sum:3");
283 SELF_CHECK_EQ(quick(3), OK);
284
285 ops.from_str("sum:2+relu+sum:3+relu");
286 SELF_CHECK_EQ(quick(4), OK);
287
288 return OK;
289}
290
291static int check_tags() {
292 for (int tag_ = dnnl_format_tag_undef; tag_ != dnnl_format_tag_last;
293 tag_++) {
294 dnnl_format_tag_t format_tag = (dnnl_format_tag_t)tag_;
295 const char *str_tag = fmt_tag2str(format_tag);
296 int ndims = 1;
297 for (char c = (char)('a' + DNNL_MAX_NDIMS - 1); c >= 'a'; c--) {
298 if (strchr(str_tag, c)) {
299 ndims = c - 'a' + 1;
300 break;
301 }
302 }
303 const dnnl_dims_t dims
304 = {7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47};
305 auto md_from_str = dnn_mem_t::init_md(ndims, dims, dnnl_f32, str_tag);
306
307 dnnl_memory_desc_t md_from_tag;
308 DNN_SAFE(dnnl_memory_desc_create_with_tag(
309 &md_from_tag, ndims, dims, dnnl_f32, format_tag),
310 CRIT);
311 int eq = dnnl_memory_desc_equal(md_from_tag, md_from_str);
312 SELF_CHECK_EQ(eq, 1);
313
314 DNN_SAFE(dnnl_memory_desc_destroy(md_from_tag), CRIT);
315 }
316
317 return OK;
318}
319
320static int check_trim_tags() {
321 {
322 std::string tag = "BA16a16b4a";
323 std::string ndims_trimmed_tag = trim_tag(tag, 1);
324 SELF_CHECK_EQ(true, ndims_trimmed_tag == "A16a4a");
325 }
326 {
327 std::string tag = "BA16a16b4a";
328 std::string masked_trimmed_tag = trim_tag_by_mask(tag, 2);
329 SELF_CHECK_EQ(true, masked_trimmed_tag == "A16a");
330 }
331 {
332 std::string tag = "abcd";
333 std::string ndims_trimmed_tag = trim_tag(tag, 2);
334 SELF_CHECK_EQ(true, ndims_trimmed_tag == "ab");
335 }
336 {
337 std::string tag = "abcd";
338 std::string masked_trimmed_tag = trim_tag_by_mask(tag, 10);
339 SELF_CHECK_EQ(true, masked_trimmed_tag == "ab");
340 }
341 {
342 std::string tag = "abcd";
343 std::string masked_trimmed_tag = trim_tag_by_mask(tag, 12);
344 SELF_CHECK_EQ(true, masked_trimmed_tag == "ab");
345 }
346 {
347 std::string tag = "aBcd16b";
348 std::string ndims_trimmed_tag = trim_tag(tag, 2);
349 SELF_CHECK_EQ(true, ndims_trimmed_tag == "aB16b");
350 }
351 {
352 std::string tag = "aBcd16b";
353 std::string masked_trimmed_tag = trim_tag_by_mask(tag, 2);
354 SELF_CHECK_EQ(true, masked_trimmed_tag == "A16a");
355 }
356 {
357 std::string tag = "BA16a16b4a";
358 std::string masked_trimmed_tag = trim_tag_by_mask(tag, 1);
359 SELF_CHECK_EQ(true, masked_trimmed_tag == "A16a4a");
360 }
361 {
362 std::string tag = "BADC2c16a8d16b4a";
363 std::string masked_trimmed_tag = trim_tag_by_mask(tag, 14);
364 SELF_CHECK_EQ(true, masked_trimmed_tag == "ACB2b8c16a");
365 }
366
367 return OK;
368}
369
370static int check_skip_impl() {
371 skip_impl = "gemm";
372 SELF_CHECK_EQ(true, maybe_skip("x64:gemm:jit"));
373
374 skip_impl = "ref,x64:gemm";
375 SELF_CHECK_EQ(true, maybe_skip("x64:gemm:jit"));
376
377 skip_impl = "this,finds,nothing";
378 SELF_CHECK_EQ(false, maybe_skip("x64:gemm:jit"));
379
380 return OK;
381}
382
383void common() {
384 RUN(check_simple_enums());
385 RUN(check_attr2str());
386 RUN(check_attr());
387 RUN(check_post_ops2str());
388 RUN(check_str2post_ops());
389 RUN(check_tags());
390 RUN(check_trim_tags());
391 RUN(check_skip_impl());
392}
393
394} // namespace self
395