1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <ctype.h> |
18 | #include <stdio.h> |
19 | #include <stdlib.h> |
20 | #include <string.h> |
21 | |
22 | #include "oneapi/dnnl/dnnl.h" |
23 | |
24 | #include "dnnl_common.hpp" |
25 | #include "dnnl_debug.hpp" |
26 | |
27 | #include "matmul/matmul.hpp" |
28 | |
29 | namespace matmul { |
30 | |
31 | dnnl_data_type_t prb_t::get_dt(data_kind_t data_kind) const { |
32 | switch (data_kind) { |
33 | case SRC: return src_dt(); |
34 | case WEI: return wei_dt(); |
35 | case BIA: return bia_dt; |
36 | case DST: return dst_dt(); |
37 | default: assert(!"unexpected" ); return dnnl_data_type_undef; |
38 | } |
39 | } |
40 | |
41 | float *prb_t::generate_scales(int arg) const { |
42 | const auto &scales = attr.scales; |
43 | if (scales.is_def()) return nullptr; |
44 | |
45 | const auto &e = scales.get(arg); |
46 | if (e.policy == policy_t::COMMON) { |
47 | float *s = (float *)zmalloc(sizeof(float), 4); |
48 | SAFE_V(s != nullptr ? OK : FAIL); |
49 | s[0] = e.scale; |
50 | return s; |
51 | } |
52 | |
53 | assert(arg == DNNL_ARG_WEIGHTS); |
54 | assert(e.policy == policy_t::PER_OC); |
55 | |
56 | float *s = (float *)zmalloc(sizeof(float) * n, 64); |
57 | SAFE_V(s != nullptr ? OK : FAIL); |
58 | |
59 | const float K = 32; |
60 | /* scale in [1/K .. K], with starting point at e.scale */ |
61 | float s_val[2] = {e.scale, e.scale / 2}; |
62 | for (int64_t i = 0; i < n; ++i) { |
63 | int64_t si = i % 2; // 0 -> left, 1 -> right |
64 | s[i] = s_val[si]; |
65 | if (si == 0) { |
66 | s_val[si] /= 2.; |
67 | // turn around to become ~K |
68 | if (s_val[si] < 1. / K) s_val[si] *= K * K; |
69 | } else { |
70 | s_val[si] *= 2.; |
71 | // turn around to become ~K |
72 | if (s_val[si] > K) s_val[si] /= K * K; |
73 | } |
74 | } |
75 | return s; |
76 | } |
77 | |
78 | int32_t *prb_t::generate_zero_points( |
79 | int arg, const attr_t::zero_points_t &zero_points, int N) const { |
80 | if (zero_points.is_def(arg)) return nullptr; |
81 | |
82 | const auto &e = zero_points.get(arg); |
83 | if (e.policy == policy_t::COMMON) { |
84 | int32_t *zp = (int32_t *)zmalloc(sizeof(int32_t), 4); |
85 | SAFE_V(zp != nullptr ? OK : FAIL); |
86 | zp[0] = e.value; |
87 | return zp; |
88 | } |
89 | |
90 | assert(e.policy == policy_t::PER_DIM_1); |
91 | |
92 | int32_t *zp = (int32_t *)zmalloc(sizeof(int32_t) * N, 64); |
93 | SAFE_V(zp != nullptr ? OK : FAIL); |
94 | |
95 | for (int i = 0; i < N; ++i) |
96 | zp[i] = e.value + i % 3; |
97 | return zp; |
98 | } |
99 | |
100 | std::ostream &operator<<(std::ostream &s, const prb_t &prb) { |
101 | dump_global_params(s); |
102 | settings_t def; |
103 | |
104 | bool has_default_dts = true; |
105 | for (const auto &i_dt : prb.dt) |
106 | has_default_dts = has_default_dts && i_dt == dnnl_f32; |
107 | |
108 | if (canonical || !has_default_dts) s << "--dt=" << prb.dt << " " ; |
109 | if (canonical || prb.stag != def.stag[0]) s << "--stag=" << prb.stag << " " ; |
110 | if (canonical || prb.wtag != def.wtag[0]) s << "--wtag=" << prb.wtag << " " ; |
111 | if (canonical || prb.dtag != def.dtag[0]) s << "--dtag=" << prb.dtag << " " ; |
112 | if (canonical || prb.strides != def.strides[0]) |
113 | s << "--strides=" << vdims2str(prb.strides) << " " ; |
114 | |
115 | if (canonical || prb.src_runtime_dim_mask().any() |
116 | || prb.weights_runtime_dim_mask().any()) |
117 | s << "--runtime_dims_masks=" << prb.src_runtime_dim_mask().to_ulong() |
118 | << ":" << prb.weights_runtime_dim_mask().to_ulong() << " " ; |
119 | |
120 | if (canonical || prb.bia_dt != def.bia_dt[0]) { |
121 | s << "--bia_dt=" << prb.bia_dt << " " ; |
122 | |
123 | if (canonical || prb.bia_mask != def.bia_mask[0]) |
124 | s << "--bia_mask=" << prb.bia_mask << " " ; |
125 | } |
126 | |
127 | s << prb.attr; |
128 | if (canonical || prb.ctx_init != def.ctx_init[0]) |
129 | s << "--ctx-init=" << prb.ctx_init << " " ; |
130 | if (canonical || prb.ctx_exe != def.ctx_exe[0]) |
131 | s << "--ctx-exe=" << prb.ctx_exe << " " ; |
132 | |
133 | s << static_cast<const prb_vdims_t &>(prb); |
134 | |
135 | return s; |
136 | } |
137 | |
138 | } // namespace matmul |
139 | |