1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_GEMM_INNER_PRODUCT_HPP |
18 | #define CPU_GEMM_INNER_PRODUCT_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include <memory> |
23 | |
24 | #include "common/c_types_map.hpp" |
25 | #include "common/primitive.hpp" |
26 | #include "common/type_helpers.hpp" |
27 | #include "common/utils.hpp" |
28 | |
29 | #include "cpu/gemm/gemm.hpp" |
30 | #include "cpu/gemm_inner_product_utils.hpp" |
31 | |
32 | #include "cpu/cpu_inner_product_pd.hpp" |
33 | |
34 | namespace dnnl { |
35 | namespace impl { |
36 | namespace cpu { |
37 | |
38 | template <impl::data_type_t data_type> |
39 | struct gemm_inner_product_fwd_t : public primitive_t { |
40 | struct pd_t : public cpu_inner_product_fwd_pd_t { |
41 | using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t; |
42 | |
43 | DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_fwd_t); |
44 | |
45 | status_t init(engine_t *engine) { |
46 | using namespace utils; |
47 | |
48 | const bool ok = true && is_fwd() && !has_zero_dim_memory() |
49 | && everyone_is(data_type, src_md()->data_type, |
50 | weights_md()->data_type, dst_md()->data_type, |
51 | with_bias() ? weights_md(1)->data_type : data_type) |
52 | && attr()->has_default_values( |
53 | primitive_attr_t::skip_mask_t::post_ops |
54 | | primitive_attr_t::skip_mask_t::sum_dt) |
55 | && attr()->post_ops_.check_sum_consistent_dt( |
56 | dst_md()->data_type) |
57 | && set_default_params() == status::success |
58 | && dense_gemm_consitency_check( |
59 | src_md(), weights_md(), dst_md()) |
60 | && inner_product_utils::post_ops_ok( |
61 | attr()->post_ops_, &dst_md_) |
62 | && attr_.set_default_formats(dst_md(0)) == status::success; |
63 | |
64 | return ok ? status::success : status::unimplemented; |
65 | } |
66 | }; |
67 | |
68 | gemm_inner_product_fwd_t(const pd_t *apd) |
69 | : primitive_t(apd), postops_in_ip_(false) {} |
70 | |
71 | status_t init(engine_t *engine) override { |
72 | const bool has_bias = pd()->with_bias(); |
73 | const bool has_eltwise |
74 | = pd()->attr()->post_ops_.find(primitive_kind::eltwise) >= 0; |
75 | const bool has_binary |
76 | = pd()->attr()->post_ops_.find(primitive_kind::binary) >= 0; |
77 | postops_in_ip_ = has_bias || has_eltwise || has_binary; |
78 | |
79 | CHECK(safe_ptr_assign(pp_kernel_, |
80 | inner_product_utils::pp_kernel_t::create(pd(), true))); |
81 | |
82 | auto sum_idx = pd()->attr()->post_ops_.find(primitive_kind::sum); |
83 | beta_ = sum_idx >= 0 ? pd()->attr()->post_ops_.entry_[sum_idx].sum.scale |
84 | : 0.0; |
85 | |
86 | return pp_kernel_->create_kernel(); |
87 | } |
88 | |
89 | typedef typename prec_traits<data_type>::type data_t; |
90 | |
91 | status_t execute(const exec_ctx_t &ctx) const override { |
92 | return execute_forward(ctx); |
93 | } |
94 | |
95 | private: |
96 | status_t execute_forward(const exec_ctx_t &ctx) const; |
97 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
98 | |
99 | std::unique_ptr<inner_product_utils::pp_kernel_t> pp_kernel_; |
100 | bool postops_in_ip_; |
101 | float beta_; |
102 | }; |
103 | |
104 | template <impl::data_type_t data_type> |
105 | struct gemm_inner_product_bwd_data_t : public primitive_t { |
106 | struct pd_t : public cpu_inner_product_bwd_data_pd_t { |
107 | using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t; |
108 | |
109 | DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_data_t); |
110 | |
111 | status_t init(engine_t *engine) { |
112 | bool ok = true && desc()->prop_kind == prop_kind::backward_data |
113 | && !has_zero_dim_memory() |
114 | && utils::everyone_is(data_type, diff_src_md()->data_type, |
115 | weights_md()->data_type, diff_dst_md()->data_type) |
116 | && attr()->has_default_values() |
117 | && set_default_params() == status::success |
118 | && dense_gemm_consitency_check( |
119 | diff_src_md(), weights_md(), diff_dst_md()); |
120 | return ok ? status::success : status::unimplemented; |
121 | } |
122 | }; |
123 | |
124 | gemm_inner_product_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} |
125 | typedef typename prec_traits<data_type>::type data_t; |
126 | |
127 | status_t execute(const exec_ctx_t &ctx) const override { |
128 | return execute_backward_data(ctx); |
129 | } |
130 | |
131 | private: |
132 | status_t execute_backward_data(const exec_ctx_t &ctx) const; |
133 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
134 | }; |
135 | |
136 | template <impl::data_type_t data_type> |
137 | struct gemm_inner_product_bwd_weights_t : public primitive_t { |
138 | struct pd_t : public cpu_inner_product_bwd_weights_pd_t { |
139 | using cpu_inner_product_bwd_weights_pd_t:: |
140 | cpu_inner_product_bwd_weights_pd_t; |
141 | |
142 | DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_weights_t); |
143 | |
144 | status_t init(engine_t *engine) { |
145 | bool ok = true && desc()->prop_kind == prop_kind::backward_weights |
146 | && !has_zero_dim_memory() |
147 | && utils::everyone_is(data_type, src_md()->data_type, |
148 | diff_weights_md()->data_type, |
149 | diff_dst_md()->data_type, |
150 | with_bias() ? diff_weights_md(1)->data_type |
151 | : data_type) |
152 | && attr()->has_default_values() |
153 | && set_default_params() == status::success |
154 | && dense_gemm_consitency_check( |
155 | src_md(), diff_weights_md(), diff_dst_md()); |
156 | |
157 | return ok ? status::success : status::unimplemented; |
158 | } |
159 | }; |
160 | |
161 | gemm_inner_product_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {} |
162 | typedef typename prec_traits<data_type>::type data_t; |
163 | |
164 | status_t execute(const exec_ctx_t &ctx) const override { |
165 | return execute_backward_weights(ctx); |
166 | } |
167 | |
168 | private: |
169 | status_t execute_backward_weights(const exec_ctx_t &ctx) const; |
170 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
171 | }; |
172 | |
173 | } // namespace cpu |
174 | } // namespace impl |
175 | } // namespace dnnl |
176 | |
177 | #endif |
178 | |
179 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
180 | |