1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <ctype.h>
18#include <stdio.h>
19#include <stdlib.h>
20#include <string.h>
21
22#include "oneapi/dnnl/dnnl.h"
23
24#include "dnnl_common.hpp"
25#include "dnnl_debug.hpp"
26
27#include "matmul/matmul.hpp"
28
29namespace matmul {
30
31dnnl_data_type_t prb_t::get_dt(data_kind_t data_kind) const {
32 switch (data_kind) {
33 case SRC: return src_dt();
34 case WEI: return wei_dt();
35 case BIA: return bia_dt;
36 case DST: return dst_dt();
37 default: assert(!"unexpected"); return dnnl_data_type_undef;
38 }
39}
40
41float *prb_t::generate_scales(int arg) const {
42 const auto &scales = attr.scales;
43 if (scales.is_def()) return nullptr;
44
45 const auto &e = scales.get(arg);
46 if (e.policy == policy_t::COMMON) {
47 float *s = (float *)zmalloc(sizeof(float), 4);
48 SAFE_V(s != nullptr ? OK : FAIL);
49 s[0] = e.scale;
50 return s;
51 }
52
53 assert(arg == DNNL_ARG_WEIGHTS);
54 assert(e.policy == policy_t::PER_OC);
55
56 float *s = (float *)zmalloc(sizeof(float) * n, 64);
57 SAFE_V(s != nullptr ? OK : FAIL);
58
59 const float K = 32;
60 /* scale in [1/K .. K], with starting point at e.scale */
61 float s_val[2] = {e.scale, e.scale / 2};
62 for (int64_t i = 0; i < n; ++i) {
63 int64_t si = i % 2; // 0 -> left, 1 -> right
64 s[i] = s_val[si];
65 if (si == 0) {
66 s_val[si] /= 2.;
67 // turn around to become ~K
68 if (s_val[si] < 1. / K) s_val[si] *= K * K;
69 } else {
70 s_val[si] *= 2.;
71 // turn around to become ~K
72 if (s_val[si] > K) s_val[si] /= K * K;
73 }
74 }
75 return s;
76}
77
78int32_t *prb_t::generate_zero_points(
79 int arg, const attr_t::zero_points_t &zero_points, int N) const {
80 if (zero_points.is_def(arg)) return nullptr;
81
82 const auto &e = zero_points.get(arg);
83 if (e.policy == policy_t::COMMON) {
84 int32_t *zp = (int32_t *)zmalloc(sizeof(int32_t), 4);
85 SAFE_V(zp != nullptr ? OK : FAIL);
86 zp[0] = e.value;
87 return zp;
88 }
89
90 assert(e.policy == policy_t::PER_DIM_1);
91
92 int32_t *zp = (int32_t *)zmalloc(sizeof(int32_t) * N, 64);
93 SAFE_V(zp != nullptr ? OK : FAIL);
94
95 for (int i = 0; i < N; ++i)
96 zp[i] = e.value + i % 3;
97 return zp;
98}
99
100std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
101 dump_global_params(s);
102 settings_t def;
103
104 bool has_default_dts = true;
105 for (const auto &i_dt : prb.dt)
106 has_default_dts = has_default_dts && i_dt == dnnl_f32;
107
108 if (canonical || !has_default_dts) s << "--dt=" << prb.dt << " ";
109 if (canonical || prb.stag != def.stag[0]) s << "--stag=" << prb.stag << " ";
110 if (canonical || prb.wtag != def.wtag[0]) s << "--wtag=" << prb.wtag << " ";
111 if (canonical || prb.dtag != def.dtag[0]) s << "--dtag=" << prb.dtag << " ";
112 if (canonical || prb.strides != def.strides[0])
113 s << "--strides=" << vdims2str(prb.strides) << " ";
114
115 if (canonical || prb.src_runtime_dim_mask().any()
116 || prb.weights_runtime_dim_mask().any())
117 s << "--runtime_dims_masks=" << prb.src_runtime_dim_mask().to_ulong()
118 << ":" << prb.weights_runtime_dim_mask().to_ulong() << " ";
119
120 if (canonical || prb.bia_dt != def.bia_dt[0]) {
121 s << "--bia_dt=" << prb.bia_dt << " ";
122
123 if (canonical || prb.bia_mask != def.bia_mask[0])
124 s << "--bia_mask=" << prb.bia_mask << " ";
125 }
126
127 s << prb.attr;
128 if (canonical || prb.ctx_init != def.ctx_init[0])
129 s << "--ctx-init=" << prb.ctx_init << " ";
130 if (canonical || prb.ctx_exe != def.ctx_exe[0])
131 s << "--ctx-exe=" << prb.ctx_exe << " ";
132
133 s << static_cast<const prb_vdims_t &>(prb);
134
135 return s;
136}
137
138} // namespace matmul
139