1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_MATMUL_GEMM_BF16_MATMUL_HPP
18#define CPU_MATMUL_GEMM_BF16_MATMUL_HPP
19
20#include <assert.h>
21
22#include "common/bfloat16.hpp"
23#include "common/c_types_map.hpp"
24#include "common/primitive.hpp"
25#include "common/type_helpers.hpp"
26
27#include "cpu/gemm_inner_product_utils.hpp"
28
29#include "cpu/matmul/cpu_matmul_pd.hpp"
30#include "cpu/matmul/gemm_based_common.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35namespace matmul {
36
37template <impl::data_type_t dst_type>
38struct gemm_bf16_matmul_t : public primitive_t {
39 struct pd_t : public cpu_matmul_pd_t {
40 using cpu_matmul_pd_t::cpu_matmul_pd_t;
41
42 DECLARE_COMMON_PD_T("gemm:jit", gemm_bf16_matmul_t);
43
44 status_t init(engine_t *engine);
45 const gemm_based::params_t &params() const { return params_; }
46
47 int nthr_; // To not exceed the limit in execute used for set up.
48
49 private:
50 status_t check_and_configure_attributes();
51 gemm_based::params_t params_;
52 };
53
54 gemm_bf16_matmul_t(const pd_t *apd) : primitive_t(apd) {}
55
56 status_t init(engine_t *engine) override {
57 if (pd()->params().has_pp_kernel_) {
58 const bool has_runtime_dims
59 = memory_desc_wrapper(pd()->dst_md()).has_runtime_dims();
60 const int nthr = pd()->nthr_;
61 const dim_t batch = pd()->batch();
62 const dim_t M = pd()->M();
63
64 // mb value is calculated based on work-sharing using
65 // balance211 in execute()
66 dim_t mb = DNNL_RUNTIME_DIM_VAL;
67 if (!has_runtime_dims && ((batch * M) % nthr == 0)) {
68 const dim_t m_per_thr = nstl::max<dim_t>(1, (batch * M) / nthr);
69 if (m_per_thr >= M && m_per_thr % M == 0) {
70 mb = M;
71 } else if (m_per_thr < M && M % m_per_thr == 0) {
72 mb = m_per_thr;
73 }
74 }
75
76 const bool skip_sum
77 = should_skip_sum_po(); // sum can be done by gemm itself
78 CHECK(safe_ptr_assign(pp_kernel_,
79 inner_product_utils::pp_kernel_t::create(pd()->N(), mb,
80 pd()->ldc(), &pd()->params().pp_attr_,
81 pd()->desc()->bias_desc.data_type,
82 pd()->desc()->accum_data_type, pd()->dst_md(),
83 skip_sum)));
84 return pp_kernel_->create_kernel();
85 }
86 return status::success;
87 }
88
89 static constexpr data_type_t src_type = data_type::bf16;
90 static constexpr data_type_t weights_type = data_type::bf16;
91 static constexpr data_type_t acc_type = data_type::f32;
92
93 typedef typename prec_traits<src_type>::type src_data_t;
94 typedef typename prec_traits<weights_type>::type weights_data_t;
95 typedef typename prec_traits<dst_type>::type dst_data_t;
96 typedef typename prec_traits<acc_type>::type acc_data_t;
97
98 status_t execute(const exec_ctx_t &ctx) const override {
99 return execute_ref(ctx);
100 }
101
102private:
103 bool should_skip_sum_po() const noexcept;
104 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
105 status_t execute_ref(const exec_ctx_t &ctx) const;
106
107 std::unique_ptr<inner_product_utils::pp_kernel_t> pp_kernel_;
108};
109
110} // namespace matmul
111} // namespace cpu
112} // namespace impl
113} // namespace dnnl
114
115#endif
116