1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <stdio.h>
18#include <stdlib.h>
19#include <string.h>
20
21#include "oneapi/dnnl/dnnl.h"
22
23#include "dnnl_common.hpp"
24#include "dnnl_debug.hpp"
25
26#include "ip/ip.hpp"
27
28namespace ip {
29
30float *prb_t::generate_scales(int arg) const {
31 const auto &scales = attr.scales;
32 if (scales.is_def()) return nullptr;
33
34 const auto &e = scales.get(arg);
35 if (e.policy == policy_t::COMMON) {
36 float *s = (float *)zmalloc(sizeof(float), 4);
37 SAFE_V(s != nullptr ? OK : FAIL);
38 s[0] = e.scale;
39 return s;
40 }
41
42 assert(e.policy == policy_t::PER_OC);
43 const auto mask = attr_t::get_default_mask(e.policy, arg);
44 int64_t s_nelems = desc_nelems(arg, mask);
45
46 float *s = (float *)zmalloc(sizeof(float) * s_nelems, 64);
47 SAFE_V(s != nullptr ? OK : FAIL);
48
49 const float K = 32;
50 /* scale in [1/K .. K], with starting point at e.scale */
51 float s_val[2] = {e.scale, e.scale / 2};
52 for (int64_t i = 0; i < oc; ++i) {
53 int64_t si = i % 2; // 0 -> left, 1 -> right
54 s[i] = s_val[si];
55 if (si == 0) {
56 s_val[si] /= 2.;
57 // turn around to become ~K
58 if (s_val[si] < 1. / K) s_val[si] *= K * K;
59 } else {
60 s_val[si] *= 2.;
61 // turn around to become ~K
62 if (s_val[si] > K) s_val[si] /= K * K;
63 }
64 }
65 return s;
66}
67
68int str2desc(desc_t *desc, const char *str) {
69 // Canonical form: mbXicXidXihXiwXocXnS,
70 // where
71 // X is integer
72 // S is string
73 // note: symbol `_` is ignored.
74 // Cubic/square shapes are supported by specifying just highest dimension.
75
76 desc_t d {0};
77 d.mb = 2;
78
79 const char *s = str;
80 assert(s);
81
82#define CASE_NN(prb, c) \
83 do { \
84 if (!strncmp(prb, s, strlen(prb))) { \
85 ok = 1; \
86 s += strlen(prb); \
87 char *end_s; \
88 d.c = strtol(s, &end_s, 10); \
89 s += (end_s - s); \
90 if (d.c < 0) return FAIL; \
91 /* printf("@@@debug: %s: %d\n", prb, d. c); */ \
92 } \
93 } while (0)
94#define CASE_N(c) CASE_NN(#c, c)
95 while (*s) {
96 int ok = 0;
97 CASE_N(mb);
98 CASE_N(ic);
99 CASE_N(ih);
100 CASE_N(iw);
101 CASE_N(id);
102 CASE_N(oc);
103 if (*s == 'n') {
104 d.name = s + 1;
105 break;
106 }
107 if (*s == '_') ++s;
108 if (!ok) return FAIL;
109 }
110#undef CASE_NN
111#undef CASE_N
112
113 if (d.ic == 0 || d.oc == 0) return FAIL;
114
115 if (sanitize_desc(d.ndims, {d.id}, {d.ih}, {d.iw}, {1}) != OK) return FAIL;
116
117 *desc = d;
118
119 return OK;
120}
121
122std::ostream &operator<<(std::ostream &s, const desc_t &d) {
123 bool print_d = true, print_h = true, print_w = true;
124 print_dhw(print_d, print_h, print_w, d.ndims, {d.id}, {d.ih}, {d.iw});
125
126 if (canonical || d.mb != 2) s << "mb" << d.mb;
127
128 s << "ic" << d.ic;
129
130 if (print_d) s << "id" << d.id;
131 if (print_h) s << "ih" << d.ih;
132 if (print_w) s << "iw" << d.iw;
133
134 s << "oc" << d.oc;
135
136 if (!d.name.empty()) s << "n" << d.name;
137
138 return s;
139}
140
141std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
142 dump_global_params(s);
143 settings_t def;
144
145 if (canonical || prb.dir != def.dir[0]) s << "--dir=" << prb.dir << " ";
146 if (canonical || prb.cfg != def.cfg[0]) s << "--cfg=" << prb.cfg << " ";
147 if (canonical || prb.stag != def.stag[0]) s << "--stag=" << prb.stag << " ";
148 if (canonical || prb.wtag != def.wtag[0]) s << "--wtag=" << prb.wtag << " ";
149 if (canonical || prb.dtag != def.dtag[0]) s << "--dtag=" << prb.dtag << " ";
150
151 s << prb.attr;
152 if (canonical || prb.ctx_init != def.ctx_init[0])
153 s << "--ctx-init=" << prb.ctx_init << " ";
154 if (canonical || prb.ctx_exe != def.ctx_exe[0])
155 s << "--ctx-exe=" << prb.ctx_exe << " ";
156
157 s << static_cast<const desc_t &>(prb);
158
159 return s;
160}
161
162dims_t desc_t::src_dims() const {
163 dims_t src_dims {mb, ic, id, ih, iw};
164 for (int d = 0; d < 5 - ndims; ++d) {
165 src_dims.erase(src_dims.begin() + 2);
166 }
167
168 return src_dims;
169}
170
171dims_t desc_t::wei_dims() const {
172 dims_t wei_dims {oc, ic, id, ih, iw};
173 for (int d = 0; d < 5 - ndims; ++d) {
174 wei_dims.erase(wei_dims.begin() + 2);
175 }
176
177 return wei_dims;
178}
179
180dims_t desc_t::bia_dims() const {
181 dims_t bia_dims {oc};
182 return bia_dims;
183}
184
185dims_t desc_t::dst_dims() const {
186 dims_t dst_dims {mb, oc};
187 return dst_dims;
188}
189
190int64_t desc_t::desc_nelems(int arg, int mask) const {
191 dims_t dims;
192 switch (arg) {
193 case DNNL_ARG_SRC: dims = src_dims(); break;
194 case DNNL_ARG_WEIGHTS: dims = wei_dims(); break;
195 case DNNL_ARG_DST: dims = dst_dims(); break;
196 default: assert(!"unsupported arg");
197 }
198
199 int64_t nelems = 1;
200 for (int d = 0; d < ndims; d++) {
201 nelems *= (mask & (1 << d)) ? dims[d] : 1;
202 }
203 return nelems;
204}
205
206} // namespace ip
207