1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <ctype.h> |
18 | #include <stdio.h> |
19 | #include <stdlib.h> |
20 | #include <string.h> |
21 | |
22 | #include "oneapi/dnnl/dnnl.h" |
23 | |
24 | #include "dnnl_common.hpp" |
25 | #include "dnnl_debug.hpp" |
26 | |
27 | #include "brgemm/brgemm.hpp" |
28 | |
29 | namespace brgemm { |
30 | |
31 | dnnl_data_type_t prb_t::get_dt(data_kind_t data_kind) const { |
32 | switch (data_kind) { |
33 | case SRC: return src_dt(); |
34 | case WEI: return wei_dt(); |
35 | case BIA: return bia_dt; |
36 | case DST: return dst_dt(); |
37 | case ACC: return acc_dt(); |
38 | default: assert(!"unexpected" ); return dnnl_data_type_undef; |
39 | } |
40 | } |
41 | |
42 | void prb_t::generate_oscales() { |
43 | // Brgemm takes single pointer oscale, but relies on a combination of arg |
44 | // scales attributes. This helps to reuse attributes from primitives, but |
45 | // requires them to pre-compute oscale = src_scale * wei_scale[:] |
46 | const auto &attr_scales = attr.scales; |
47 | |
48 | const auto &src_sc = attr_scales.get(DNNL_ARG_SRC); |
49 | float src_scale_val = 1.0f; |
50 | if (!src_sc.is_def()) { |
51 | assert(src_sc.policy == policy_t::COMMON); |
52 | src_scale_val = src_sc.scale; |
53 | } |
54 | |
55 | const auto &wei_sc = attr_scales.get(DNNL_ARG_WEIGHTS); |
56 | |
57 | if (wei_sc.policy == policy_t::COMMON) { |
58 | scales = (float *)zmalloc(sizeof(float), 4); |
59 | SAFE_V(scales != nullptr ? OK : FAIL); |
60 | scales[0] = wei_sc.scale; |
61 | if (!src_sc.is_def()) { scales[0] *= src_scale_val; } |
62 | return; |
63 | } |
64 | |
65 | assert(wei_sc.policy == policy_t::PER_OC); |
66 | |
67 | scales = (float *)zmalloc(sizeof(float) * n, 64); |
68 | SAFE_V(scales != nullptr ? OK : FAIL); |
69 | |
70 | const float K = 32; |
71 | /* scale in [1/K .. K], with starting point at wei_sc.scale */ |
72 | float s[2] = {wei_sc.scale, wei_sc.scale / 2}; |
73 | for (int64_t i = 0; i < n; ++i) { |
74 | int64_t si = i % 2; // 0 -> left, 1 -> right |
75 | scales[i] = s[si] * src_scale_val; |
76 | if (si == 0) { |
77 | s[si] /= 2.; |
78 | if (s[si] < 1. / K) s[si] *= K * K; // turn around to become ~K |
79 | } else { |
80 | s[si] *= 2.; |
81 | if (s[si] > K) s[si] /= K * K; // turn around to become ~K |
82 | } |
83 | } |
84 | } |
85 | |
86 | int32_t *prb_t::generate_zero_points( |
87 | int arg, const attr_t::zero_points_t &zero_points, int N) { |
88 | if (zero_points.is_def(arg)) return nullptr; |
89 | |
90 | const auto &e = zero_points.get(arg); |
91 | if (e.policy == policy_t::COMMON) { |
92 | int32_t *zp = (int32_t *)zmalloc(sizeof(int32_t), 4); |
93 | SAFE_V(zp != nullptr ? OK : FAIL); |
94 | zp[0] = e.value; |
95 | return zp; |
96 | } |
97 | |
98 | assert(e.policy == policy_t::PER_DIM_1); |
99 | |
100 | int32_t *zp = (int32_t *)zmalloc(sizeof(int32_t) * N, 64); |
101 | SAFE_V(zp != nullptr ? OK : FAIL); |
102 | |
103 | for (int i = 0; i < N; ++i) |
104 | zp[i] = e.value + i % 3; |
105 | return zp; |
106 | } |
107 | |
108 | std::ostream &operator<<(std::ostream &s, const prb_t &prb) { |
109 | dump_global_params(s); |
110 | settings_t def; |
111 | |
112 | bool has_default_dts = true; |
113 | for (const auto &i_dt : prb.dt) |
114 | has_default_dts = has_default_dts && i_dt == dnnl_f32; |
115 | |
116 | if (canonical || !has_default_dts) s << "--dt=" << prb.dt << " " ; |
117 | if (canonical || prb.stag != def.stag[0]) s << "--stag=" << prb.stag << " " ; |
118 | if (canonical || prb.wtag != def.wtag[0]) s << "--wtag=" << prb.wtag << " " ; |
119 | if (canonical || prb.dtag != def.dtag[0]) s << "--dtag=" << prb.dtag << " " ; |
120 | if (canonical || prb.ld != def.ld[0]) { |
121 | s << "--ld=" ; |
122 | if (prb.ld[0] != 0) s << prb.ld[0]; |
123 | s << ":" ; |
124 | if (prb.ld[1] != 0) s << prb.ld[1]; |
125 | s << ":" ; |
126 | if (prb.ld[2] != 0) s << prb.ld[2]; |
127 | s << " " ; |
128 | } |
129 | |
130 | if (canonical || prb.bia_dt != def.bia_dt[0]) |
131 | s << "--bia_dt=" << prb.bia_dt << " " ; |
132 | |
133 | if (canonical || prb.alpha != def.alpha[0]) |
134 | s << "--alpha=" << prb.alpha << " " ; |
135 | if (canonical || prb.beta != def.beta[0]) s << "--beta=" << prb.beta << " " ; |
136 | if (canonical || prb.batch_size != def.batch_size[0]) |
137 | s << "--bs=" << prb.batch_size << " " ; |
138 | if (canonical || prb.brgemm_attr != def.brgemm_attr[0]) |
139 | s << "--brgemm-attr=" << prb.brgemm_attr << " " ; |
140 | |
141 | s << prb.attr; |
142 | s << static_cast<const prb_vdims_t &>(prb); |
143 | |
144 | return s; |
145 | } |
146 | |
147 | } // namespace brgemm |
148 | |