1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <ctype.h>
18#include <stdio.h>
19#include <stdlib.h>
20#include <string.h>
21
22#include "oneapi/dnnl/dnnl.h"
23
24#include "dnnl_common.hpp"
25#include "dnnl_debug.hpp"
26
27#include "brgemm/brgemm.hpp"
28
29namespace brgemm {
30
31dnnl_data_type_t prb_t::get_dt(data_kind_t data_kind) const {
32 switch (data_kind) {
33 case SRC: return src_dt();
34 case WEI: return wei_dt();
35 case BIA: return bia_dt;
36 case DST: return dst_dt();
37 case ACC: return acc_dt();
38 default: assert(!"unexpected"); return dnnl_data_type_undef;
39 }
40}
41
42void prb_t::generate_oscales() {
43 // Brgemm takes single pointer oscale, but relies on a combination of arg
44 // scales attributes. This helps to reuse attributes from primitives, but
45 // requires them to pre-compute oscale = src_scale * wei_scale[:]
46 const auto &attr_scales = attr.scales;
47
48 const auto &src_sc = attr_scales.get(DNNL_ARG_SRC);
49 float src_scale_val = 1.0f;
50 if (!src_sc.is_def()) {
51 assert(src_sc.policy == policy_t::COMMON);
52 src_scale_val = src_sc.scale;
53 }
54
55 const auto &wei_sc = attr_scales.get(DNNL_ARG_WEIGHTS);
56
57 if (wei_sc.policy == policy_t::COMMON) {
58 scales = (float *)zmalloc(sizeof(float), 4);
59 SAFE_V(scales != nullptr ? OK : FAIL);
60 scales[0] = wei_sc.scale;
61 if (!src_sc.is_def()) { scales[0] *= src_scale_val; }
62 return;
63 }
64
65 assert(wei_sc.policy == policy_t::PER_OC);
66
67 scales = (float *)zmalloc(sizeof(float) * n, 64);
68 SAFE_V(scales != nullptr ? OK : FAIL);
69
70 const float K = 32;
71 /* scale in [1/K .. K], with starting point at wei_sc.scale */
72 float s[2] = {wei_sc.scale, wei_sc.scale / 2};
73 for (int64_t i = 0; i < n; ++i) {
74 int64_t si = i % 2; // 0 -> left, 1 -> right
75 scales[i] = s[si] * src_scale_val;
76 if (si == 0) {
77 s[si] /= 2.;
78 if (s[si] < 1. / K) s[si] *= K * K; // turn around to become ~K
79 } else {
80 s[si] *= 2.;
81 if (s[si] > K) s[si] /= K * K; // turn around to become ~K
82 }
83 }
84}
85
86int32_t *prb_t::generate_zero_points(
87 int arg, const attr_t::zero_points_t &zero_points, int N) {
88 if (zero_points.is_def(arg)) return nullptr;
89
90 const auto &e = zero_points.get(arg);
91 if (e.policy == policy_t::COMMON) {
92 int32_t *zp = (int32_t *)zmalloc(sizeof(int32_t), 4);
93 SAFE_V(zp != nullptr ? OK : FAIL);
94 zp[0] = e.value;
95 return zp;
96 }
97
98 assert(e.policy == policy_t::PER_DIM_1);
99
100 int32_t *zp = (int32_t *)zmalloc(sizeof(int32_t) * N, 64);
101 SAFE_V(zp != nullptr ? OK : FAIL);
102
103 for (int i = 0; i < N; ++i)
104 zp[i] = e.value + i % 3;
105 return zp;
106}
107
108std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
109 dump_global_params(s);
110 settings_t def;
111
112 bool has_default_dts = true;
113 for (const auto &i_dt : prb.dt)
114 has_default_dts = has_default_dts && i_dt == dnnl_f32;
115
116 if (canonical || !has_default_dts) s << "--dt=" << prb.dt << " ";
117 if (canonical || prb.stag != def.stag[0]) s << "--stag=" << prb.stag << " ";
118 if (canonical || prb.wtag != def.wtag[0]) s << "--wtag=" << prb.wtag << " ";
119 if (canonical || prb.dtag != def.dtag[0]) s << "--dtag=" << prb.dtag << " ";
120 if (canonical || prb.ld != def.ld[0]) {
121 s << "--ld=";
122 if (prb.ld[0] != 0) s << prb.ld[0];
123 s << ":";
124 if (prb.ld[1] != 0) s << prb.ld[1];
125 s << ":";
126 if (prb.ld[2] != 0) s << prb.ld[2];
127 s << " ";
128 }
129
130 if (canonical || prb.bia_dt != def.bia_dt[0])
131 s << "--bia_dt=" << prb.bia_dt << " ";
132
133 if (canonical || prb.alpha != def.alpha[0])
134 s << "--alpha=" << prb.alpha << " ";
135 if (canonical || prb.beta != def.beta[0]) s << "--beta=" << prb.beta << " ";
136 if (canonical || prb.batch_size != def.batch_size[0])
137 s << "--bs=" << prb.batch_size << " ";
138 if (canonical || prb.brgemm_attr != def.brgemm_attr[0])
139 s << "--brgemm-attr=" << prb.brgemm_attr << " ";
140
141 s << prb.attr;
142 s << static_cast<const prb_vdims_t &>(prb);
143
144 return s;
145}
146
147} // namespace brgemm
148