1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_MATMUL_GEMM_F32_MATMUL_HPP
18#define CPU_MATMUL_GEMM_F32_MATMUL_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/primitive.hpp"
24#include "common/type_helpers.hpp"
25
26#include "cpu/gemm_inner_product_utils.hpp"
27
28#include "cpu/matmul/cpu_matmul_pd.hpp"
29#include "cpu/matmul/gemm_based_common.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace cpu {
34namespace matmul {
35
36struct gemm_f32_matmul_t : public primitive_t {
37 struct pd_t : public cpu_matmul_pd_t {
38 using cpu_matmul_pd_t::cpu_matmul_pd_t;
39
40 DECLARE_COMMON_PD_T("gemm:jit", gemm_f32_matmul_t);
41
42 status_t init(engine_t *engine);
43 const gemm_based::params_t &params() const { return params_; }
44
45 int nthr_; // To not exceed the limit in execute used for set up.
46
47 private:
48 status_t check_and_configure_attributes();
49 gemm_based::params_t params_;
50 };
51
52 gemm_f32_matmul_t(const pd_t *apd) : primitive_t(apd) {}
53
54 status_t init(engine_t *engine) override {
55 if (pd()->params().has_pp_kernel_) {
56 const bool has_runtime_dims
57 = memory_desc_wrapper(pd()->dst_md()).has_runtime_dims();
58 const int nthr = pd()->nthr_;
59 const dim_t batch = pd()->batch();
60 const dim_t M = pd()->M();
61
62 // mb value is calculated based on work-sharing using
63 // balance211 in execute()
64 dim_t mb = DNNL_RUNTIME_DIM_VAL;
65 if (!has_runtime_dims && ((batch * M) % nthr == 0)) {
66 const dim_t m_per_thr = nstl::max<dim_t>(1, (batch * M) / nthr);
67 if (m_per_thr >= M && m_per_thr % M == 0) {
68 mb = M;
69 } else if (m_per_thr < M && M % m_per_thr == 0) {
70 mb = m_per_thr;
71 }
72 }
73
74 const bool skip_sum = should_skip_sum_po(
75 pd()->dst_md()
76 ->data_type); // sum can be done by gemm itself
77 CHECK(safe_ptr_assign(pp_kernel_,
78 inner_product_utils::pp_kernel_t::create(pd()->N(), mb,
79 pd()->ldc(), &pd()->params().pp_attr_,
80 pd()->desc()->bias_desc.data_type,
81 pd()->desc()->accum_data_type, pd()->dst_md(),
82 skip_sum)));
83 return pp_kernel_->create_kernel();
84 }
85
86 return status::success;
87 }
88
89 static constexpr data_type_t src_type = data_type::f32;
90 static constexpr data_type_t weights_type = data_type::f32;
91 static constexpr data_type_t dst_type = data_type::f32;
92 static constexpr data_type_t acc_type = data_type::f32;
93
94 typedef typename prec_traits<src_type>::type src_data_t;
95 typedef typename prec_traits<weights_type>::type weights_data_t;
96 typedef typename prec_traits<dst_type>::type dst_data_t;
97 typedef typename prec_traits<acc_type>::type acc_data_t;
98
99 status_t execute(const exec_ctx_t &ctx) const override {
100 return execute_ref(ctx);
101 }
102
103private:
104 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
105 status_t execute_ref(const exec_ctx_t &ctx) const;
106 bool should_skip_sum_po(data_type_t dst_dt) const noexcept;
107
108 std::unique_ptr<inner_product_utils::pp_kernel_t> pp_kernel_;
109};
110
111} // namespace matmul
112} // namespace cpu
113} // namespace impl
114} // namespace dnnl
115
116#endif
117