1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_GEMM_INNER_PRODUCT_HPP
18#define CPU_GEMM_INNER_PRODUCT_HPP
19
20#include <assert.h>
21
22#include <memory>
23
24#include "common/c_types_map.hpp"
25#include "common/primitive.hpp"
26#include "common/type_helpers.hpp"
27#include "common/utils.hpp"
28
29#include "cpu/gemm/gemm.hpp"
30#include "cpu/gemm_inner_product_utils.hpp"
31
32#include "cpu/cpu_inner_product_pd.hpp"
33
34namespace dnnl {
35namespace impl {
36namespace cpu {
37
38template <impl::data_type_t data_type>
39struct gemm_inner_product_fwd_t : public primitive_t {
40 struct pd_t : public cpu_inner_product_fwd_pd_t {
41 using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t;
42
43 DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_fwd_t);
44
45 status_t init(engine_t *engine) {
46 using namespace utils;
47
48 const bool ok = true && is_fwd() && !has_zero_dim_memory()
49 && everyone_is(data_type, src_md()->data_type,
50 weights_md()->data_type, dst_md()->data_type,
51 with_bias() ? weights_md(1)->data_type : data_type)
52 && attr()->has_default_values(
53 primitive_attr_t::skip_mask_t::post_ops
54 | primitive_attr_t::skip_mask_t::sum_dt)
55 && attr()->post_ops_.check_sum_consistent_dt(
56 dst_md()->data_type)
57 && set_default_params() == status::success
58 && dense_gemm_consitency_check(
59 src_md(), weights_md(), dst_md())
60 && inner_product_utils::post_ops_ok(
61 attr()->post_ops_, &dst_md_)
62 && attr_.set_default_formats(dst_md(0)) == status::success;
63
64 return ok ? status::success : status::unimplemented;
65 }
66 };
67
68 gemm_inner_product_fwd_t(const pd_t *apd)
69 : primitive_t(apd), postops_in_ip_(false) {}
70
71 status_t init(engine_t *engine) override {
72 const bool has_bias = pd()->with_bias();
73 const bool has_eltwise
74 = pd()->attr()->post_ops_.find(primitive_kind::eltwise) >= 0;
75 const bool has_binary
76 = pd()->attr()->post_ops_.find(primitive_kind::binary) >= 0;
77 postops_in_ip_ = has_bias || has_eltwise || has_binary;
78
79 CHECK(safe_ptr_assign(pp_kernel_,
80 inner_product_utils::pp_kernel_t::create(pd(), true)));
81
82 auto sum_idx = pd()->attr()->post_ops_.find(primitive_kind::sum);
83 beta_ = sum_idx >= 0 ? pd()->attr()->post_ops_.entry_[sum_idx].sum.scale
84 : 0.0;
85
86 return pp_kernel_->create_kernel();
87 }
88
89 typedef typename prec_traits<data_type>::type data_t;
90
91 status_t execute(const exec_ctx_t &ctx) const override {
92 return execute_forward(ctx);
93 }
94
95private:
96 status_t execute_forward(const exec_ctx_t &ctx) const;
97 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
98
99 std::unique_ptr<inner_product_utils::pp_kernel_t> pp_kernel_;
100 bool postops_in_ip_;
101 float beta_;
102};
103
104template <impl::data_type_t data_type>
105struct gemm_inner_product_bwd_data_t : public primitive_t {
106 struct pd_t : public cpu_inner_product_bwd_data_pd_t {
107 using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t;
108
109 DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_data_t);
110
111 status_t init(engine_t *engine) {
112 bool ok = true && desc()->prop_kind == prop_kind::backward_data
113 && !has_zero_dim_memory()
114 && utils::everyone_is(data_type, diff_src_md()->data_type,
115 weights_md()->data_type, diff_dst_md()->data_type)
116 && attr()->has_default_values()
117 && set_default_params() == status::success
118 && dense_gemm_consitency_check(
119 diff_src_md(), weights_md(), diff_dst_md());
120 return ok ? status::success : status::unimplemented;
121 }
122 };
123
124 gemm_inner_product_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}
125 typedef typename prec_traits<data_type>::type data_t;
126
127 status_t execute(const exec_ctx_t &ctx) const override {
128 return execute_backward_data(ctx);
129 }
130
131private:
132 status_t execute_backward_data(const exec_ctx_t &ctx) const;
133 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
134};
135
136template <impl::data_type_t data_type>
137struct gemm_inner_product_bwd_weights_t : public primitive_t {
138 struct pd_t : public cpu_inner_product_bwd_weights_pd_t {
139 using cpu_inner_product_bwd_weights_pd_t::
140 cpu_inner_product_bwd_weights_pd_t;
141
142 DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_weights_t);
143
144 status_t init(engine_t *engine) {
145 bool ok = true && desc()->prop_kind == prop_kind::backward_weights
146 && !has_zero_dim_memory()
147 && utils::everyone_is(data_type, src_md()->data_type,
148 diff_weights_md()->data_type,
149 diff_dst_md()->data_type,
150 with_bias() ? diff_weights_md(1)->data_type
151 : data_type)
152 && attr()->has_default_values()
153 && set_default_params() == status::success
154 && dense_gemm_consitency_check(
155 src_md(), diff_weights_md(), diff_dst_md());
156
157 return ok ? status::success : status::unimplemented;
158 }
159 };
160
161 gemm_inner_product_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {}
162 typedef typename prec_traits<data_type>::type data_t;
163
164 status_t execute(const exec_ctx_t &ctx) const override {
165 return execute_backward_weights(ctx);
166 }
167
168private:
169 status_t execute_backward_weights(const exec_ctx_t &ctx) const;
170 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
171};
172
173} // namespace cpu
174} // namespace impl
175} // namespace dnnl
176
177#endif
178
179// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
180