/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/dnnl_thread.hpp"

#include "cpu/x64/matmul/brgemm_matmul_reorders.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

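// Checks that the source/destination pair can be handled by this specialized
// matrix B reorder (plain layout -> blocked VNNI layout) and fills
// matmul_conf_for_reorder_ with everything needed to generate the copy_b
// kernel.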
status_t brgemm_matmul_matrix_B_reorder_t::pd_t::init(
        engine_t *engine, engine_t *src_engine, engine_t *dst_engine) {
    using namespace status;
    using namespace format_tag;

    status_t status = cpu_reorder_pd_t::init(engine, src_engine, dst_engine);
    if (status != success) return status;

    const memory_desc_wrapper id(src_md_), od(dst_md_);
    const int ndims = id.ndims();

    const auto type_i = id.data_type();
    const auto type_o = od.data_type();
    // TODO: enable support for type_i != type_o cases
    const bool dt_ok = true && type_i == type_o
            && utils::one_of(type_o, data_type::s8, data_type::bf16,
                    data_type::f16, data_type::f32);
    const bool is_f16 = utils::one_of(data_type::f16, type_i, type_o);
    const bool is_s8s8 = type_i == data_type::s8 && type_o == data_type::s8;
    const bool has_adj_scale
            = od.extra().flags & memory_extra_flags::scale_adjust;
    const bool args_ok = true && dt_ok && id.is_dense()
            && utils::one_of(ndims, 2, 3)
            && IMPLICATION(is_f16, mayiuse(avx512_core_fp16))
            && IMPLICATION(!is_f16, mayiuse(avx512_core))
            && IMPLICATION(is_s8s8, mayiuse(avx512_core_vnni)) && !has_adj_scale
            && attr()->has_default_values() && od.is_blocking_desc()
            && !od.has_runtime_dims_or_strides() && !od.has_zero_dim();
    if (!args_ok) return invalid_arguments;

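    // The source is expected in a plain row-major layout (ab for 2D, abc for
    // batched 3D); the destination must match one of the blocked copy_b
    // layouts listed below.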
    const auto &dims = id.dims();
    // TODO: enable for itag = {ba, acb}
    format_tag_t itag = id.matches_one_of_tag(ab, abc);
    format_tag_t otag = format_tag::undef;

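    // The VNNI granularity of the destination type (4 for s8, 2 for bf16/f16,
    // 1 for f32) selects the inner grouping of K elements in the blocked tag;
    // the N block size may be 64, 48, 32 or 16.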
    const auto vnni_granularity = data_type_vnni_granularity(type_o);
    switch (vnni_granularity) {
        case 4:
            otag = od.matches_one_of_tag(aCB16b64c4b, BA16a64b4a, aCB16b48c4b,
                    BA16a48b4a, aCB16b32c4b, BA16a32b4a, aCB16b16c4b,
                    BA16a16b4a);
            break;
        case 2:
            otag = od.matches_one_of_tag(aCB16b64c2b, BA16a64b2a, aCB16b48c2b,
                    BA16a48b2a, aCB16b32c2b, BA16a32b2a, aCB16b16c2b,
                    BA16a16b2a);
            break;
        case 1:
            otag = od.matches_one_of_tag(aCB16b64c, BA16a64b, aCB16b48c,
                    BA16a48b, aCB16b32c, BA16a32b, aCB16b16c, BA16a16b);
            break;
        default: otag = format_tag::undef;
    }

    if (utils::one_of(format_tag::undef, itag, otag)) return invalid_arguments;

    // initialize all required fields to generate copy_b kernel
    matmul_conf_for_reorder_.wei_tag = itag;
    matmul_conf_for_reorder_.batch = ndims > 2 ? dims[ndims - 3] : 1;
    matmul_conf_for_reorder_.K = dims[ndims - 2];
    matmul_conf_for_reorder_.N = dims[ndims - 1];
    matmul_conf_for_reorder_.wei_n_blk = matmul_conf_for_reorder_.N_blk
            = matmul_conf_for_reorder_.LDB = matmul::get_default_n_block(otag);
    matmul_conf_for_reorder_.N_tail
            = matmul_conf_for_reorder_.N % matmul_conf_for_reorder_.N_blk;
    matmul_conf_for_reorder_.K_blk = 16 * vnni_granularity;
    matmul_conf_for_reorder_.K_tail
            = matmul_conf_for_reorder_.K % matmul_conf_for_reorder_.K_blk;
    matmul_conf_for_reorder_.src_dt = matmul_conf_for_reorder_.wei_dt = type_o;
    matmul_conf_for_reorder_.a_dt_sz = matmul_conf_for_reorder_.tr_a_dt_sz
            = types::data_type_size(matmul_conf_for_reorder_.src_dt);
    matmul_conf_for_reorder_.b_dt_sz = matmul_conf_for_reorder_.tr_b_dt_sz
            = types::data_type_size(matmul_conf_for_reorder_.wei_dt);
    matmul_conf_for_reorder_.s8s8_comp_b_str = utils::rnd_up(
            matmul_conf_for_reorder_.N, matmul_conf_for_reorder_.wei_n_blk);
    matmul_conf_for_reorder_.s8s8_comp_n_str
            = matmul_conf_for_reorder_.wei_n_blk;
    matmul_conf_for_reorder_.s8s8_compensation_required
            = od.extra().flags & memory_extra_flags::compensation_conv_s8s8;
    const bool req_asymmetric_comp = od.extra().flags
            & memory_extra_flags::compensation_conv_asymmetric_src;
    matmul_conf_for_reorder_.src_zp_type = req_asymmetric_comp
            ? brgemm_broadcast_t::per_tensor
            : brgemm_broadcast_t::none;
    matmul_conf_for_reorder_.has_zero_point_a
            = matmul_conf_for_reorder_.src_zp_type != brgemm_broadcast_t::none;
    matmul_conf_for_reorder_.isa = is_f16 ? avx512_core_fp16 : avx512_core;

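    // Compensation (if requested) must be computed per batch and N, i.e. the
    // compensation mask has to cover every dimension except K.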
    auto mask_ok = [&](bool check, int mask) {
        return IMPLICATION(
                check, mask == (1 << ndims) - 1 - (1 << (ndims - 2)));
    };

    const bool comp_masks_ok = true
            && mask_ok(matmul_conf_for_reorder_.s8s8_compensation_required,
                    od.extra().compensation_mask)
            && mask_ok(req_asymmetric_comp, od.extra().asymm_compensation_mask);
    if (!comp_masks_ok) return invalid_arguments;

    init_scratchpad();

    return status::success;
}

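// Standard reorder pd creation; an init() failure is reported as
// unimplemented so that other reorder implementations can be tried instead.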
status_t brgemm_matmul_matrix_B_reorder_t::pd_t::create(
        reorder_pd_t **reorder_pd, engine_t *engine,
        const primitive_attr_t *attr, engine_t *src_engine,
        const memory_desc_t *src_md, engine_t *dst_engine,
        const memory_desc_t *dst_md) {
    using namespace status;

    auto _pd = new pd_t(
            attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md);
    if (_pd == nullptr) return out_of_memory;
    if (_pd->init(engine, src_engine, dst_engine) != success) {
        delete _pd;
        return unimplemented;
    }

    _pd->init_scratchpad_md();
    return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
}

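// Reorders matrix B into the blocked VNNI layout expected by the brgemm-based
// matmul: for every (batch, N block) pair the copy_b kernel packs the data
// K_blk rows at a time and, when requested, produces the s8s8 / zero-point
// compensation values.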
status_t brgemm_matmul_matrix_B_reorder_t::execute_body(
        const exec_ctx_t &ctx) const {
    using namespace utils;

    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_FROM);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_TO);
    const memory_desc_wrapper &src_d = pd()->src_md();
    const memory_desc_wrapper &dst_d = pd()->dst_md();
    const auto sdt_sz = types::data_type_size(src_d.data_type());
    const auto type_o = dst_d.data_type();
    const auto ddt_sz = types::data_type_size(type_o);

    const auto &kernel_conf = pd()->matmul_conf_for_reorder_;
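    // Compensation values are stored past the blocked data in the destination
    // buffer: first the s8s8 compensation (if any), then the zero-point
    // compensation.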
    const size_t comp_offset_bytes
            = dst_d.size() - dst_d.additional_buffer_size();
    const size_t s8s8_comp_size_bytes = kernel_conf.s8s8_compensation_required
            ? dst_d.additional_buffer_size(
                    memory_extra_flags::compensation_conv_s8s8)
            : 0;
    const size_t zp_comp_offset_bytes
            = comp_offset_bytes + s8s8_comp_size_bytes;
    int32_t *cp = kernel_conf.s8s8_compensation_required
            ? reinterpret_cast<int32_t *>(dst + comp_offset_bytes)
            : nullptr;
    int32_t *zp = kernel_conf.has_zero_point_a
            ? reinterpret_cast<int32_t *>(dst + zp_comp_offset_bytes)
            : nullptr;

    const int ndims = src_d.ndims();
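// Computes a byte offset into the plain source / blocked destination,
// dispatching between the 2D and batched 3D cases.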
#define get_blk_off(md, dt_sz, batch, d0, d1) \
    (ndims == 3 ? (dt_sz) * (md).blk_off((batch), (d0), (d1)) \
                : (dt_sz) * (md).blk_off((d0), (d1)))

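    // Each (batch, N block) pair is processed independently; within a pair the
    // copy_b kernel is invoked once per full K block and once more for the
    // K tail, if present.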
    parallel_nd(kernel_conf.batch, div_up(kernel_conf.N, kernel_conf.N_blk),
            [&](dim_t batch, dim_t n_blk_idx) {
                const auto n = n_blk_idx * kernel_conf.N_blk;
                const bool is_N_tail = (kernel_conf.N - n < kernel_conf.N_blk);
                auto ker_exec_ctx = matmul::jit_brgemm_matmul_copy_b_t::ctx_t();
                ker_exec_ctx.current_N_blk
                        = is_N_tail ? kernel_conf.N_tail : kernel_conf.N_blk;

                const auto comp_offset = batch * kernel_conf.s8s8_comp_b_str
                        + n_blk_idx * kernel_conf.s8s8_comp_n_str;

                ker_exec_ctx.zp_a_compensation_ptr
                        = kernel_conf.has_zero_point_a
                        ? (void *)&zp[comp_offset]
                        : nullptr;
                ker_exec_ctx.compensation_ptr
                        = kernel_conf.s8s8_compensation_required
                        ? (void *)&cp[comp_offset]
                        : nullptr;

                // required to compute zp compensation
                int tmp_neg_a_zp_val = -1;
                ker_exec_ctx.zp_a_neg_value_ptr = &tmp_neg_a_zp_val;

                int k_blk_idx = 0;
                for (; k_blk_idx < kernel_conf.K / kernel_conf.K_blk;
                        k_blk_idx++) {
                    const auto k = k_blk_idx * kernel_conf.K_blk;
                    ker_exec_ctx.src = (void *)&src[get_blk_off(
                            src_d, sdt_sz, batch, k, n)];
                    ker_exec_ctx.tr_src = (void *)&dst[get_blk_off(
                            dst_d, ddt_sz, batch, k_blk_idx, n_blk_idx)];
                    ker_exec_ctx.current_K_start = k;
                    ker_exec_ctx.current_K_iters = kernel_conf.K_blk;
                    (*kernel_)(&ker_exec_ctx);
                }
                if (kernel_conf.K_tail > 0) {
                    const auto k = k_blk_idx * kernel_conf.K_blk;
                    ker_exec_ctx.src = (void *)&src[get_blk_off(
                            src_d, sdt_sz, batch, k, n)];
                    const auto dst_offset = get_blk_off(
                            dst_d, ddt_sz, batch, k_blk_idx, n_blk_idx);
                    ker_exec_ctx.tr_src = (void *)&dst[dst_offset];
                    ker_exec_ctx.current_K_start = k;
                    ker_exec_ctx.current_K_iters = kernel_conf.K_tail;
                    (*kernel_)(&ker_exec_ctx);
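                    // Zero-fill the remaining rows of the K block past the
                    // (VNNI-padded) tail so the blocked buffer is fully
                    // initialized.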
                    const auto vnni_granularity
                            = data_type_vnni_granularity(type_o);
                    const auto dst_zero_out_offset
                            = rnd_up(kernel_conf.K_tail, vnni_granularity)
                            * kernel_conf.N_blk * ddt_sz;
                    const auto elems_to_zero
                            = rnd_dn(kernel_conf.K_blk - kernel_conf.K_tail,
                                      vnni_granularity)
                            * kernel_conf.N_blk * ddt_sz;
                    array_set(&dst[dst_offset + dst_zero_out_offset], 0,
                            elems_to_zero);
                }
            });

#undef get_blk_off

    return status::success;
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl