1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_GEMM_BF16_INNER_PRODUCT_HPP
18#define CPU_X64_GEMM_BF16_INNER_PRODUCT_HPP
19
20#include <assert.h>
21
22#include <memory>
23
24#include "common/c_types_map.hpp"
25#include "common/dnnl_thread.hpp"
26#include "common/memory_tracking.hpp"
27#include "common/primitive.hpp"
28#include "common/type_helpers.hpp"
29#include "common/utils.hpp"
30
31#include "cpu/cpu_engine.hpp"
32#include "cpu/gemm/gemm.hpp"
33#include "cpu/gemm_inner_product_utils.hpp"
34
35#include "cpu/x64/jit_uni_convert_xf16.hpp"
36
37namespace dnnl {
38namespace impl {
39namespace cpu {
40namespace x64 {
41
42template <data_type_t dst_data_type>
43struct gemm_bf16_inner_product_fwd_t : public primitive_t {
44 struct pd_t : public cpu_inner_product_fwd_pd_t {
45 using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t;
46
47 DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_inner_product_fwd_t);
48
49 status_t init(engine_t *engine) {
50 using namespace utils;
51 using namespace data_type;
52
53 bool ok = true && mayiuse(avx512_core) && is_fwd()
54 && !has_zero_dim_memory()
55 && everyone_is(
56 bf16, src_md()->data_type, weights_md()->data_type)
57 && dst_data_type == dst_md()->data_type
58 && IMPLICATION(with_bias(),
59 one_of(weights_md(1)->data_type, f32, bf16))
60 && attr()->has_default_values(
61 primitive_attr_t::skip_mask_t::post_ops,
62 dst_md()->data_type)
63 && attr()->post_ops_.check_sum_consistent_dt(
64 dst_md()->data_type)
65 && inner_product_utils::post_ops_ok(
66 attr()->post_ops_, &dst_md_)
67 && set_default_params() == status::success
68 && dense_gemm_consitency_check(
69 src_md(), weights_md(), dst_md())
70 && attr_.set_default_formats(dst_md(0)) == status::success;
71 if (!ok) return status::unimplemented;
72
73 dst_is_acc_ = dst_data_type == f32;
74
75 init_scratchpad();
76
77 return status::success;
78 }
79
80 bool dst_is_acc_;
81
82 protected:
83 void init_scratchpad() {
84 if (!dst_is_acc_) {
85 auto scratchpad = scratchpad_registry().registrar();
86 scratchpad.template book<acc_data_t>(
87 memory_tracking::names::key_iprod_int_dat_in_acc_dt,
88 MB() * OC());
89 }
90 }
91 };
92
93 gemm_bf16_inner_product_fwd_t(const pd_t *apd) : primitive_t(apd) {}
94
95 typedef typename prec_traits<dst_data_type>::type dst_data_t;
96 typedef typename prec_traits<data_type::f32>::type acc_data_t;
97 typedef typename prec_traits<data_type::bf16>::type src_data_t;
98 typedef typename prec_traits<data_type::bf16>::type wei_data_t;
99
100 status_t init(engine_t *engine) override {
101 const bool has_bias = pd()->with_bias();
102 const bool has_eltwise
103 = pd()->attr()->post_ops_.find(primitive_kind::eltwise) >= 0;
104 const bool has_binary
105 = pd()->attr()->post_ops_.find(primitive_kind::binary) >= 0;
106 const bool has_sum_as_postops = !pd()->dst_is_acc_;
107 postops_in_ip_ = false
108 || !pd()->dst_is_acc_ /* includes has_sum_as_postops */
109 || has_bias || has_eltwise || has_binary;
110 if (postops_in_ip_)
111 CHECK(safe_ptr_assign(pp_kernel_,
112 inner_product_utils::pp_kernel_t::create(
113 pd(), !has_sum_as_postops)));
114
115 auto sum_idx = pd()->attr()->post_ops_.find(primitive_kind::sum);
116 beta_ = sum_idx >= 0 && !has_sum_as_postops
117 ? pd()->attr()->post_ops_.entry_[sum_idx].sum.scale
118 : 0.0;
119
120 return (pp_kernel_) ? pp_kernel_->create_kernel() : status::success;
121 }
122
123 status_t execute(const exec_ctx_t &ctx) const override {
124 return execute_forward(ctx);
125 }
126
127private:
128 std::unique_ptr<inner_product_utils::pp_kernel_t> pp_kernel_;
129 bool postops_in_ip_;
130 float beta_;
131
132 status_t execute_forward(const exec_ctx_t &ctx) const;
133 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
134};
135
136template <data_type_t diff_src_data_type>
137struct gemm_bf16_inner_product_bwd_data_t : public primitive_t {
138 struct pd_t : public cpu_inner_product_bwd_data_pd_t {
139 using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t;
140
141 DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_bf16_inner_product_bwd_data_t);
142
143 status_t init(engine_t *engine) {
144 using namespace data_type;
145
146 bool ok = true && mayiuse(avx512_core)
147 && desc()->prop_kind == prop_kind::backward_data
148 && !has_zero_dim_memory()
149 && utils::everyone_is(bf16, weights_md()->data_type,
150 diff_dst_md()->data_type)
151 && diff_src_data_type == diff_src_md()->data_type
152 && attr()->has_default_values()
153 && this->set_default_params() == status::success
154 && dense_gemm_consitency_check(
155 diff_src_md(), weights_md(), diff_dst_md());
156 if (!ok) return status::unimplemented;
157
158 diff_src_is_acc_ = diff_src_data_type == data_type::f32;
159
160 init_scratchpad();
161
162 return status::success;
163 }
164
165 bool diff_src_is_acc_;
166
167 private:
168 void init_scratchpad() {
169 if (!diff_src_is_acc_) {
170 auto scratchpad = scratchpad_registry().registrar();
171 scratchpad.template book<acc_data_t>(
172 memory_tracking::names::key_iprod_int_dat_in_acc_dt,
173 MB() * IC_total_padded());
174 }
175 }
176 };
177
178 gemm_bf16_inner_product_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}
179
180 typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
181 typedef typename prec_traits<data_type::f32>::type acc_data_t;
182 typedef typename prec_traits<diff_src_data_type>::type diff_src_data_t;
183 typedef typename prec_traits<data_type::bf16>::type wei_data_t;
184
185 status_t execute(const exec_ctx_t &ctx) const override {
186 return execute_backward_data(ctx);
187 }
188
189private:
190 status_t execute_backward_data(const exec_ctx_t &ctx) const;
191 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
192};
193
194template <data_type_t diff_wei_data_type>
195struct gemm_bf16_inner_product_bwd_weights_t : public primitive_t {
196 struct pd_t : public cpu_inner_product_bwd_weights_pd_t {
197 using cpu_inner_product_bwd_weights_pd_t::
198 cpu_inner_product_bwd_weights_pd_t;
199
200 DECLARE_COMMON_PD_T(
201 GEMM_IMPL_STR, gemm_bf16_inner_product_bwd_weights_t);
202
203 status_t init(engine_t *engine) {
204 using namespace utils;
205 using namespace data_type;
206
207 bool ok = true && mayiuse(avx512_core)
208 && desc()->prop_kind == prop_kind::backward_weights
209 && !has_zero_dim_memory()
210 && everyone_is(
211 bf16, src_md()->data_type, diff_dst_md()->data_type)
212 && diff_wei_data_type == diff_weights_md()->data_type
213 && IMPLICATION(with_bias(),
214 one_of(diff_weights_md(1)->data_type, f32, bf16))
215 && attr()->has_default_values()
216 && set_default_params() == status::success
217 && dense_gemm_consitency_check(
218 src_md(), diff_weights_md(), diff_dst_md());
219
220 if (!ok) return status::unimplemented;
221
222 diff_wei_is_acc_ = diff_wei_data_type == f32;
223 bias_reduction_nthr_ = dnnl_get_max_threads();
224
225 init_scratchpad();
226
227 return status::success;
228 }
229
230 bool diff_wei_is_acc_;
231 int bias_reduction_nthr_;
232 static const dim_t bias_blksize = 32;
233
234 void get_bias_partitioning(
235 dim_t &OC_per_thread, int &nthr_OCB, int &nthr_MB) const {
236 dim_t OCB = utils::div_up(OC(), bias_blksize);
237 dim_t OCB_per_thread = utils::div_up(OCB, bias_reduction_nthr_);
238
239 OC_per_thread = OCB_per_thread * bias_blksize;
240 nthr_OCB = utils::div_up(OCB, OCB_per_thread);
241 nthr_MB = bias_reduction_nthr_ / nthr_OCB;
242
243 assert(nthr_OCB * nthr_MB <= bias_reduction_nthr_);
244 }
245
246 private:
247 void init_scratchpad() {
248 using namespace memory_tracking;
249 auto scratchpad = scratchpad_registry().registrar();
250
251 if (!diff_wei_is_acc_)
252 scratchpad.template book<acc_data_t>(
253 names::key_iprod_int_dat_in_acc_dt,
254 OC() * IC_total_padded());
255
256 if (with_bias()) {
257 dim_t OC_per_thread {0};
258 int nthr_OCB {0}, nthr_MB {0};
259 get_bias_partitioning(OC_per_thread, nthr_OCB, nthr_MB);
260
261 const bool diff_bias_is_acc = nthr_MB == 1
262 && diff_weights_md(1)->data_type == data_type::f32;
263
264 if (!diff_bias_is_acc)
265 scratchpad.template book<acc_data_t>(
266 names::key_iprod_bias_bf16_convert_wsp,
267 nthr_OCB * nthr_MB * OC_per_thread);
268 }
269 }
270 };
271
272 gemm_bf16_inner_product_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {}
273
274 status_t init(engine_t *engine) override {
275 if (pd()->with_bias())
276 CHECK(safe_ptr_assign(bias_reduction_,
277 new jit_cvt_xf16_to_ps_t(
278 data_type::bf16, true, (size_t)pd()->OC())));
279 return status::success;
280 }
281
282 typedef typename prec_traits<data_type::bf16>::type diff_dst_data_t;
283 typedef typename prec_traits<data_type::f32>::type acc_data_t;
284 typedef typename prec_traits<data_type::bf16>::type src_data_t;
285 typedef typename prec_traits<diff_wei_data_type>::type diff_wei_data_t;
286
287 status_t execute(const exec_ctx_t &ctx) const override {
288 return execute_backward_weights(ctx);
289 }
290
291private:
292 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
293
294 status_t execute_backward_weights(const exec_ctx_t &ctx) const;
295 void execute_backward_bias(const exec_ctx_t &ctx) const;
296
297 std::unique_ptr<jit_cvt_xf16_to_ps_t> bias_reduction_;
298};
299
300} // namespace x64
301} // namespace cpu
302} // namespace impl
303} // namespace dnnl
304
305#endif
306
307// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
308