1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_GEMM_BF16_INNER_PRODUCT_HPP |
18 | #define CPU_X64_GEMM_BF16_INNER_PRODUCT_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include <memory> |
23 | |
24 | #include "common/c_types_map.hpp" |
25 | #include "common/dnnl_thread.hpp" |
26 | #include "common/memory_tracking.hpp" |
27 | #include "common/primitive.hpp" |
28 | #include "common/type_helpers.hpp" |
29 | #include "common/utils.hpp" |
30 | |
31 | #include "cpu/cpu_engine.hpp" |
32 | #include "cpu/gemm/gemm.hpp" |
33 | #include "cpu/gemm_inner_product_utils.hpp" |
34 | |
35 | #include "cpu/x64/jit_uni_convert_xf16.hpp" |
36 | |
37 | namespace dnnl { |
38 | namespace impl { |
39 | namespace cpu { |
40 | namespace x64 { |
41 | |
42 | template <data_type_t dst_data_type> |
43 | struct gemm_bf16_inner_product_fwd_t : public primitive_t { |
44 | struct pd_t : public cpu_inner_product_fwd_pd_t { |
45 | using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; |
46 | |
47 | DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_inner_product_fwd_t); |
48 | |
49 | status_t init(engine_t *engine) { |
50 | using namespace utils; |
51 | using namespace data_type; |
52 | |
53 | bool ok = true && mayiuse(avx512_core) && is_fwd() |
54 | && !has_zero_dim_memory() |
55 | && everyone_is( |
56 | bf16, src_md()->data_type, weights_md()->data_type) |
57 | && dst_data_type == dst_md()->data_type |
58 | && IMPLICATION(with_bias(), |
59 | one_of(weights_md(1)->data_type, f32, bf16)) |
60 | && attr()->has_default_values( |
61 | primitive_attr_t::skip_mask_t::post_ops, |
62 | dst_md()->data_type) |
63 | && attr()->post_ops_.check_sum_consistent_dt( |
64 | dst_md()->data_type) |
65 | && inner_product_utils::post_ops_ok( |
66 | attr()->post_ops_, &dst_md_) |
67 | && set_default_params() == status::success |
68 | && dense_gemm_consitency_check( |
69 | src_md(), weights_md(), dst_md()) |
70 | && attr_.set_default_formats(dst_md(0)) == status::success; |
71 | if (!ok) return status::unimplemented; |
72 | |
73 | dst_is_acc_ = dst_data_type == f32; |
74 | |
75 | init_scratchpad(); |
76 | |
77 | return status::success; |
78 | } |
79 | |
80 | bool dst_is_acc_; |
81 | |
82 | protected: |
83 | void init_scratchpad() { |
84 | if (!dst_is_acc_) { |
85 | auto scratchpad = scratchpad_registry().registrar(); |
86 | scratchpad.template book<acc_data_t>( |
87 | memory_tracking::names::key_iprod_int_dat_in_acc_dt, |
88 | MB() * OC()); |
89 | } |
90 | } |
91 | }; |
92 | |
93 | gemm_bf16_inner_product_fwd_t(const pd_t *apd) : primitive_t(apd) {} |
94 | |
95 | typedef typename prec_traits<dst_data_type>::type dst_data_t; |
96 | typedef typename prec_traits<data_type::f32>::type acc_data_t; |
97 | typedef typename prec_traits<data_type::bf16>::type src_data_t; |
98 | typedef typename prec_traits<data_type::bf16>::type wei_data_t; |
99 | |
100 | status_t init(engine_t *engine) override { |
101 | const bool has_bias = pd()->with_bias(); |
102 | const bool has_eltwise |
103 | = pd()->attr()->post_ops_.find(primitive_kind::eltwise) >= 0; |
104 | const bool has_binary |
105 | = pd()->attr()->post_ops_.find(primitive_kind::binary) >= 0; |
106 | const bool has_sum_as_postops = !pd()->dst_is_acc_; |
107 | postops_in_ip_ = false |
108 | || !pd()->dst_is_acc_ /* includes has_sum_as_postops */ |
109 | || has_bias || has_eltwise || has_binary; |
110 | if (postops_in_ip_) |
111 | CHECK(safe_ptr_assign(pp_kernel_, |
112 | inner_product_utils::pp_kernel_t::create( |
113 | pd(), !has_sum_as_postops))); |
114 | |
115 | auto sum_idx = pd()->attr()->post_ops_.find(primitive_kind::sum); |
116 | beta_ = sum_idx >= 0 && !has_sum_as_postops |
117 | ? pd()->attr()->post_ops_.entry_[sum_idx].sum.scale |
118 | : 0.0; |
119 | |
120 | return (pp_kernel_) ? pp_kernel_->create_kernel() : status::success; |
121 | } |
122 | |
123 | status_t execute(const exec_ctx_t &ctx) const override { |
124 | return execute_forward(ctx); |
125 | } |
126 | |
127 | private: |
128 | std::unique_ptr<inner_product_utils::pp_kernel_t> pp_kernel_; |
129 | bool postops_in_ip_; |
130 | float beta_; |
131 | |
132 | status_t execute_forward(const exec_ctx_t &ctx) const; |
133 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
134 | }; |
135 | |
136 | template <data_type_t diff_src_data_type> |
137 | struct gemm_bf16_inner_product_bwd_data_t : public primitive_t { |
138 | struct pd_t : public cpu_inner_product_bwd_data_pd_t { |
139 | using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t; |
140 | |
141 | DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_inner_product_bwd_data_t); |
142 | |
143 | status_t init(engine_t *engine) { |
144 | using namespace data_type; |
145 | |
146 | bool ok = true && mayiuse(avx512_core) |
147 | && desc()->prop_kind == prop_kind::backward_data |
148 | && !has_zero_dim_memory() |
149 | && utils::everyone_is(bf16, weights_md()->data_type, |
150 | diff_dst_md()->data_type) |
151 | && diff_src_data_type == diff_src_md()->data_type |
152 | && attr()->has_default_values() |
153 | && this->set_default_params() == status::success |
154 | && dense_gemm_consitency_check( |
155 | diff_src_md(), weights_md(), diff_dst_md()); |
156 | if (!ok) return status::unimplemented; |
157 | |
158 | diff_src_is_acc_ = diff_src_data_type == data_type::f32; |
159 | |
160 | init_scratchpad(); |
161 | |
162 | return status::success; |
163 | } |
164 | |
165 | bool diff_src_is_acc_; |
166 | |
167 | private: |
168 | void init_scratchpad() { |
169 | if (!diff_src_is_acc_) { |
170 | auto scratchpad = scratchpad_registry().registrar(); |
171 | scratchpad.template book<acc_data_t>( |
172 | memory_tracking::names::key_iprod_int_dat_in_acc_dt, |
173 | MB() * IC_total_padded()); |
174 | } |
175 | } |
176 | }; |
177 | |
178 | gemm_bf16_inner_product_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} |
179 | |
180 | typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t; |
181 | typedef typename prec_traits<data_type::f32>::type acc_data_t; |
182 | typedef typename prec_traits<diff_src_data_type>::type diff_src_data_t; |
183 | typedef typename prec_traits<data_type::bf16>::type wei_data_t; |
184 | |
185 | status_t execute(const exec_ctx_t &ctx) const override { |
186 | return execute_backward_data(ctx); |
187 | } |
188 | |
189 | private: |
190 | status_t execute_backward_data(const exec_ctx_t &ctx) const; |
191 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
192 | }; |
193 | |
194 | template <data_type_t diff_wei_data_type> |
195 | struct gemm_bf16_inner_product_bwd_weights_t : public primitive_t { |
196 | struct pd_t : public cpu_inner_product_bwd_weights_pd_t { |
197 | using cpu_inner_product_bwd_weights_pd_t:: |
198 | cpu_inner_product_bwd_weights_pd_t; |
199 | |
200 | DECLARE_COMMON_PD_T( |
201 | GEMM_IMPL_STR, gemm_bf16_inner_product_bwd_weights_t); |
202 | |
203 | status_t init(engine_t *engine) { |
204 | using namespace utils; |
205 | using namespace data_type; |
206 | |
207 | bool ok = true && mayiuse(avx512_core) |
208 | && desc()->prop_kind == prop_kind::backward_weights |
209 | && !has_zero_dim_memory() |
210 | && everyone_is( |
211 | bf16, src_md()->data_type, diff_dst_md()->data_type) |
212 | && diff_wei_data_type == diff_weights_md()->data_type |
213 | && IMPLICATION(with_bias(), |
214 | one_of(diff_weights_md(1)->data_type, f32, bf16)) |
215 | && attr()->has_default_values() |
216 | && set_default_params() == status::success |
217 | && dense_gemm_consitency_check( |
218 | src_md(), diff_weights_md(), diff_dst_md()); |
219 | |
220 | if (!ok) return status::unimplemented; |
221 | |
222 | diff_wei_is_acc_ = diff_wei_data_type == f32; |
223 | bias_reduction_nthr_ = dnnl_get_max_threads(); |
224 | |
225 | init_scratchpad(); |
226 | |
227 | return status::success; |
228 | } |
229 | |
230 | bool diff_wei_is_acc_; |
231 | int bias_reduction_nthr_; |
232 | static const dim_t bias_blksize = 32; |
233 | |
234 | void get_bias_partitioning( |
235 | dim_t &OC_per_thread, int &nthr_OCB, int &nthr_MB) const { |
236 | dim_t OCB = utils::div_up(OC(), bias_blksize); |
237 | dim_t OCB_per_thread = utils::div_up(OCB, bias_reduction_nthr_); |
238 | |
239 | OC_per_thread = OCB_per_thread * bias_blksize; |
240 | nthr_OCB = utils::div_up(OCB, OCB_per_thread); |
241 | nthr_MB = bias_reduction_nthr_ / nthr_OCB; |
242 | |
243 | assert(nthr_OCB * nthr_MB <= bias_reduction_nthr_); |
244 | } |
245 | |
246 | private: |
247 | void init_scratchpad() { |
248 | using namespace memory_tracking; |
249 | auto scratchpad = scratchpad_registry().registrar(); |
250 | |
251 | if (!diff_wei_is_acc_) |
252 | scratchpad.template book<acc_data_t>( |
253 | names::key_iprod_int_dat_in_acc_dt, |
254 | OC() * IC_total_padded()); |
255 | |
256 | if (with_bias()) { |
257 | dim_t OC_per_thread {0}; |
258 | int nthr_OCB {0}, nthr_MB {0}; |
259 | get_bias_partitioning(OC_per_thread, nthr_OCB, nthr_MB); |
260 | |
261 | const bool diff_bias_is_acc = nthr_MB == 1 |
262 | && diff_weights_md(1)->data_type == data_type::f32; |
263 | |
264 | if (!diff_bias_is_acc) |
265 | scratchpad.template book<acc_data_t>( |
266 | names::key_iprod_bias_bf16_convert_wsp, |
267 | nthr_OCB * nthr_MB * OC_per_thread); |
268 | } |
269 | } |
270 | }; |
271 | |
272 | gemm_bf16_inner_product_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {} |
273 | |
274 | status_t init(engine_t *engine) override { |
275 | if (pd()->with_bias()) |
276 | CHECK(safe_ptr_assign(bias_reduction_, |
277 | new jit_cvt_xf16_to_ps_t( |
278 | data_type::bf16, true, (size_t)pd()->OC()))); |
279 | return status::success; |
280 | } |
281 | |
282 | typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t; |
283 | typedef typename prec_traits<data_type::f32>::type acc_data_t; |
284 | typedef typename prec_traits<data_type::bf16>::type src_data_t; |
285 | typedef typename prec_traits<diff_wei_data_type>::type diff_wei_data_t; |
286 | |
287 | status_t execute(const exec_ctx_t &ctx) const override { |
288 | return execute_backward_weights(ctx); |
289 | } |
290 | |
291 | private: |
292 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
293 | |
294 | status_t execute_backward_weights(const exec_ctx_t &ctx) const; |
295 | void execute_backward_bias(const exec_ctx_t &ctx) const; |
296 | |
297 | std::unique_ptr<jit_cvt_xf16_to_ps_t> bias_reduction_; |
298 | }; |
299 | |
300 | } // namespace x64 |
301 | } // namespace cpu |
302 | } // namespace impl |
303 | } // namespace dnnl |
304 | |
305 | #endif |
306 | |
307 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
308 | |