1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
#include <new>

#include "common/dnnl_thread.hpp"

#include "cpu/x64/matmul/brgemm_matmul_reorders.hpp"
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace cpu { |
24 | namespace x64 { |
25 | |
26 | status_t brgemm_matmul_matrix_B_reorder_t::pd_t::init( |
27 | engine_t *engine, engine_t *src_engine, engine_t *dst_engine) { |
28 | using namespace status; |
29 | using namespace format_tag; |
30 | |
31 | status_t status = cpu_reorder_pd_t::init(engine, src_engine, dst_engine); |
32 | if (status != success) return status; |
33 | |
34 | const memory_desc_wrapper id(src_md_), od(dst_md_); |
35 | const int ndims = id.ndims(); |
36 | |
37 | const auto type_i = id.data_type(); |
38 | const auto type_o = od.data_type(); |
39 | // TODO: enable support for type_i != type_o cases |
40 | const bool dt_ok = true && type_i == type_o |
41 | && utils::one_of(type_o, data_type::s8, data_type::bf16, |
42 | data_type::f16, data_type::f32); |
43 | const bool is_f16 = utils::one_of(data_type::f16, type_i, type_o); |
44 | const bool is_s8s8 = type_i == data_type::s8 && type_o == data_type::s8; |
45 | const bool has_adj_scale |
46 | = od.extra().flags & memory_extra_flags::scale_adjust; |
47 | const bool args_ok = true && dt_ok && id.is_dense() |
48 | && utils::one_of(ndims, 2, 3) |
49 | && IMPLICATION(is_f16, mayiuse(avx512_core_fp16)) |
50 | && IMPLICATION(!is_f16, mayiuse(avx512_core)) |
51 | && IMPLICATION(is_s8s8, mayiuse(avx512_core_vnni)) && !has_adj_scale |
52 | && attr()->has_default_values() && od.is_blocking_desc() |
53 | && !od.has_runtime_dims_or_strides() && !od.has_zero_dim(); |
54 | if (!args_ok) return invalid_arguments; |
55 | |
56 | const auto &dims = id.dims(); |
57 | // TODO: enable for itag = {ba, acb} |
58 | format_tag_t itag = id.matches_one_of_tag(ab, abc); |
59 | format_tag_t otag = format_tag::undef; |
60 | |
61 | const auto vnni_granularity = data_type_vnni_granularity(type_o); |
62 | switch (vnni_granularity) { |
63 | case 4: |
64 | otag = od.matches_one_of_tag(aCB16b64c4b, BA16a64b4a, aCB16b48c4b, |
65 | BA16a48b4a, aCB16b32c4b, BA16a32b4a, aCB16b16c4b, |
66 | BA16a16b4a); |
67 | break; |
68 | case 2: |
69 | otag = od.matches_one_of_tag(aCB16b64c2b, BA16a64b2a, aCB16b48c2b, |
70 | BA16a48b2a, aCB16b32c2b, BA16a32b2a, aCB16b16c2b, |
71 | BA16a16b2a); |
72 | break; |
73 | case 1: |
74 | otag = od.matches_one_of_tag(aCB16b64c, BA16a64b, aCB16b48c, |
75 | BA16a48b, aCB16b32c, BA16a32b, aCB16b16c, BA16a16b); |
76 | break; |
77 | default: otag = format_tag::undef; |
78 | } |
79 | |
80 | if (utils::one_of(format_tag::undef, itag, otag)) return invalid_arguments; |
81 | |
82 | // initialize all required fields to generate copy_b kernel |
83 | matmul_conf_for_reorder_.wei_tag = itag; |
84 | matmul_conf_for_reorder_.batch = ndims > 2 ? dims[ndims - 3] : 1; |
85 | matmul_conf_for_reorder_.K = dims[ndims - 2]; |
86 | matmul_conf_for_reorder_.N = dims[ndims - 1]; |
87 | matmul_conf_for_reorder_.wei_n_blk = matmul_conf_for_reorder_.N_blk |
88 | = matmul_conf_for_reorder_.LDB = matmul::get_default_n_block(otag); |
89 | matmul_conf_for_reorder_.N_tail |
90 | = matmul_conf_for_reorder_.N % matmul_conf_for_reorder_.N_blk; |
91 | matmul_conf_for_reorder_.K_blk = 16 * vnni_granularity; |
92 | matmul_conf_for_reorder_.K_tail |
93 | = matmul_conf_for_reorder_.K % matmul_conf_for_reorder_.K_blk; |
94 | matmul_conf_for_reorder_.src_dt = matmul_conf_for_reorder_.wei_dt = type_o; |
95 | matmul_conf_for_reorder_.a_dt_sz = matmul_conf_for_reorder_.tr_a_dt_sz |
96 | = types::data_type_size(matmul_conf_for_reorder_.src_dt); |
97 | matmul_conf_for_reorder_.b_dt_sz = matmul_conf_for_reorder_.tr_b_dt_sz |
98 | = types::data_type_size(matmul_conf_for_reorder_.wei_dt); |
99 | matmul_conf_for_reorder_.s8s8_comp_b_str = utils::rnd_up( |
100 | matmul_conf_for_reorder_.N, matmul_conf_for_reorder_.wei_n_blk); |
101 | matmul_conf_for_reorder_.s8s8_comp_n_str |
102 | = matmul_conf_for_reorder_.wei_n_blk; |
103 | matmul_conf_for_reorder_.s8s8_compensation_required |
104 | = od.extra().flags & memory_extra_flags::compensation_conv_s8s8; |
105 | const bool req_asymmetric_comp = od.extra().flags |
106 | & memory_extra_flags::compensation_conv_asymmetric_src; |
107 | matmul_conf_for_reorder_.src_zp_type = req_asymmetric_comp |
108 | ? brgemm_broadcast_t::per_tensor |
109 | : brgemm_broadcast_t::none; |
110 | matmul_conf_for_reorder_.has_zero_point_a |
111 | = matmul_conf_for_reorder_.src_zp_type != brgemm_broadcast_t::none; |
112 | matmul_conf_for_reorder_.isa = is_f16 ? avx512_core_fp16 : avx512_core; |
113 | |
114 | auto mask_ok = [&](bool check, int mask) { |
115 | return IMPLICATION( |
116 | check, mask == (1 << ndims) - 1 - (1 << (ndims - 2))); |
117 | }; |
118 | |
119 | const bool comp_masks_ok = true |
120 | && mask_ok(matmul_conf_for_reorder_.s8s8_compensation_required, |
121 | od.extra().compensation_mask) |
122 | && mask_ok(req_asymmetric_comp, od.extra().asymm_compensation_mask); |
123 | if (!comp_masks_ok) return invalid_arguments; |
124 | |
125 | init_scratchpad(); |
126 | |
127 | return status::success; |
128 | } |
129 | |
130 | status_t brgemm_matmul_matrix_B_reorder_t::pd_t::create( |
131 | reorder_pd_t **reorder_pd, engine_t *engine, |
132 | const primitive_attr_t *attr, engine_t *src_engine, |
133 | const memory_desc_t *src_md, engine_t *dst_engine, |
134 | const memory_desc_t *dst_md) { |
135 | using namespace status; |
136 | |
137 | auto _pd = new pd_t( |
138 | attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); |
139 | if (_pd == nullptr) return out_of_memory; |
140 | if (_pd->init(engine, src_engine, dst_engine) != success) { |
141 | delete _pd; |
142 | return unimplemented; |
143 | } |
144 | |
145 | _pd->init_scratchpad_md(); |
146 | return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd); |
147 | } |
148 | |
// Executes the reorder: runs the jit copy_b kernel over each (batch, N-block)
// tile of matrix B, writing the blocked destination layout and, when the
// destination descriptor requests it, the s8s8 / zero-point compensation
// buffers appended at the end of the destination memory.
status_t brgemm_matmul_matrix_B_reorder_t::execute_body(
        const exec_ctx_t &ctx) const {
    using namespace utils;

    const auto src = CTX_IN_MEM(const char *, DNNL_ARG_FROM);
    auto dst = CTX_OUT_MEM(char *, DNNL_ARG_TO);
    const memory_desc_wrapper &src_d = pd()->src_md();
    const memory_desc_wrapper &dst_d = pd()->dst_md();
    const auto sdt_sz = types::data_type_size(src_d.data_type());
    const auto type_o = dst_d.data_type();
    const auto ddt_sz = types::data_type_size(type_o);

    const auto &kernel_conf = pd()->matmul_conf_for_reorder_;
    // Compensation data lives in the "additional buffer" region at the tail
    // of dst: first the s8s8 compensation (if required), then the
    // asymmetric-src (zero-point) compensation.
    const size_t comp_offset_bytes
            = dst_d.size() - dst_d.additional_buffer_size();
    const size_t s8s8_comp_size_bytes = kernel_conf.s8s8_compensation_required
            ? dst_d.additional_buffer_size(
                    memory_extra_flags::compensation_conv_s8s8)
            : 0;
    const size_t zp_comp_offset_bytes
            = comp_offset_bytes + s8s8_comp_size_bytes;
    // int32 compensation output pointers; null when the corresponding
    // compensation kind is not requested.
    int32_t *cp = kernel_conf.s8s8_compensation_required
            ? reinterpret_cast<int32_t *>(dst + comp_offset_bytes)
            : nullptr;
    int32_t *zp = kernel_conf.has_zero_point_a
            ? reinterpret_cast<int32_t *>(dst + zp_comp_offset_bytes)
            : nullptr;

    const int ndims = src_d.ndims();
// Byte offset of element (d0, d1) in (an optional) batch `batch` for a
// 2D or 3D memory descriptor `md` with element size `dt_sz`.
#define get_blk_off(md, dt_sz, batch, d0, d1) \
    (ndims == 3 ? (dt_sz) * (md).blk_off((batch), (d0), (d1)) \
                : (dt_sz) * (md).blk_off((d0), (d1)))

    // Parallelize over the batch and the N-blocks; each task walks the full
    // K dimension for its tile.
    parallel_nd(kernel_conf.batch, div_up(kernel_conf.N, kernel_conf.N_blk),
            [&](dim_t batch, dim_t n_blk_idx) {
                const auto n = n_blk_idx * kernel_conf.N_blk;
                const bool is_N_tail = (kernel_conf.N - n < kernel_conf.N_blk);
                auto ker_exec_ctx = matmul::jit_brgemm_matmul_copy_b_t::ctx_t();
                ker_exec_ctx.current_N_blk
                        = is_N_tail ? kernel_conf.N_tail : kernel_conf.N_blk;

                // Element offset into the int32 compensation buffers for
                // this (batch, N-block) pair.
                const auto comp_offset = batch * kernel_conf.s8s8_comp_b_str
                        + n_blk_idx * kernel_conf.s8s8_comp_n_str;

                ker_exec_ctx.zp_a_compensation_ptr
                        = kernel_conf.has_zero_point_a
                        ? (void *)&zp[comp_offset]
                        : nullptr;
                ker_exec_ctx.compensation_ptr
                        = kernel_conf.s8s8_compensation_required
                        ? (void *)&cp[comp_offset]
                        : nullptr;

                // required to compute zp compensation
                // NOTE(review): a fixed value of -1 is fed as the negated
                // zero-point — presumably the kernel scales accumulated
                // column sums by it; confirm against the copy_b kernel.
                int tmp_neg_a_zp_val = -1;
                ker_exec_ctx.zp_a_neg_value_ptr = &tmp_neg_a_zp_val;

                // Full K blocks first, then (optionally) the K tail below.
                int k_blk_idx = 0;
                for (; k_blk_idx < kernel_conf.K / kernel_conf.K_blk;
                        k_blk_idx++) {
                    const auto k = k_blk_idx * kernel_conf.K_blk;
                    ker_exec_ctx.src = (void *)&src[get_blk_off(
                            src_d, sdt_sz, batch, k, n)];
                    ker_exec_ctx.tr_src = (void *)&dst[get_blk_off(
                            dst_d, ddt_sz, batch, k_blk_idx, n_blk_idx)];
                    ker_exec_ctx.current_K_start = k;
                    ker_exec_ctx.current_K_iters = kernel_conf.K_blk;
                    (*kernel_)(&ker_exec_ctx);
                }
                if (kernel_conf.K_tail > 0) {
                    const auto k = k_blk_idx * kernel_conf.K_blk;
                    ker_exec_ctx.src = (void *)&src[get_blk_off(
                            src_d, sdt_sz, batch, k, n)];
                    const auto dst_offset = get_blk_off(
                            dst_d, ddt_sz, batch, k_blk_idx, n_blk_idx);
                    ker_exec_ctx.tr_src = (void *)&dst[dst_offset];
                    ker_exec_ctx.current_K_start = k;
                    ker_exec_ctx.current_K_iters = kernel_conf.K_tail;
                    (*kernel_)(&ker_exec_ctx);
                    // Zero out the part of the last K block the kernel did
                    // not write (past the tail rounded up to the VNNI
                    // granularity) so the padded region is deterministic.
                    const auto vnni_granularity
                            = data_type_vnni_granularity(type_o);
                    const auto dst_zero_out_offset
                            = rnd_up(kernel_conf.K_tail, vnni_granularity)
                            * kernel_conf.N_blk * ddt_sz;
                    // Byte count (despite the name) — dst is a char buffer.
                    const auto elems_to_zero
                            = rnd_dn(kernel_conf.K_blk - kernel_conf.K_tail,
                                      vnni_granularity)
                            * kernel_conf.N_blk * ddt_sz;
                    array_set(&dst[dst_offset + dst_zero_out_offset], 0,
                            elems_to_zero);
                }
            });

#undef get_blk_off

    return status::success;
}
246 | |
247 | } // namespace x64 |
248 | } // namespace cpu |
249 | } // namespace impl |
250 | } // namespace dnnl |
251 | |