1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_GEMM_CONVOLUTION_HPP |
18 | #define CPU_GEMM_CONVOLUTION_HPP |
19 | |
20 | #include "common/broadcast_strategy.hpp" |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/memory_tracking.hpp" |
23 | #include "common/primitive.hpp" |
24 | |
25 | #include "cpu/cpu_convolution_pd.hpp" |
26 | #include "cpu/gemm/gemm.hpp" |
27 | #include "cpu/gemm_convolution_utils.hpp" |
28 | #include "cpu/primitive_attr_postops.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | |
34 | struct gemm_convolution_fwd_t : public primitive_t { |
35 | struct pd_t : public cpu_convolution_fwd_pd_t { |
36 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
37 | const typename pd_t::base_class *hint_fwd_pd) |
38 | : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} |
39 | |
40 | DECLARE_COMMON_PD_T( |
41 | GEMM_IMPL_STR, gemm_convolution_fwd_t, USE_GLOBAL_SCRATCHPAD); |
42 | |
43 | status_t init(engine_t *engine) { |
44 | using namespace data_type; |
45 | |
46 | bool ok = is_fwd() |
47 | && set_default_alg_kind(alg_kind::convolution_direct) |
48 | && expect_data_types(f32, f32, f32, f32, f32) |
49 | && !has_zero_dim_memory() |
50 | && attr()->has_default_values( |
51 | primitive_attr_t::skip_mask_t::post_ops, f32) |
52 | && post_ops_ok(); |
53 | if (!ok) return status::unimplemented; |
54 | |
55 | auto scratchpad = scratchpad_registry().registrar(); |
56 | return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, |
57 | *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_, |
58 | dnnl_get_max_threads()); |
59 | } |
60 | |
61 | conv_gemm_conf_t jcp_; |
62 | |
63 | protected: |
64 | bool post_ops_ok() const { |
65 | auto const &po = attr()->post_ops_; |
66 | auto is_eltwise |
67 | = [&](int idx) { return po.entry_[idx].is_eltwise(); }; |
68 | auto is_sum = [&](int idx) { return po.entry_[idx].is_sum(); }; |
69 | auto is_binary |
70 | = [&](int idx) { return po.entry_[idx].is_binary(); }; |
71 | |
72 | for (int idx = 0; idx < po.len(); idx++) { |
73 | bool ok = utils::one_of(true, is_sum(idx), is_binary(idx), |
74 | is_eltwise(idx)) |
75 | && IMPLICATION(is_sum(idx), idx == 0) |
76 | && IMPLICATION(is_binary(idx), |
77 | dnnl::impl::get_rhs_arg_broadcasting_strategy( |
78 | po.entry_[idx].binary.src1_desc, |
79 | dst_md_, |
80 | {broadcasting_strategy_t::scalar, |
81 | broadcasting_strategy_t:: |
82 | per_oc}) |
83 | != broadcasting_strategy_t:: |
84 | unsupported); |
85 | if (!ok) return false; |
86 | } |
87 | |
88 | return true; |
89 | } |
90 | }; |
91 | |
92 | gemm_convolution_fwd_t(const pd_t *apd) |
93 | : primitive_t(apd), post_ops_(nullptr) {} |
94 | |
95 | status_t init(engine_t *engine) override { |
96 | const data_t one = 1.0, zero = 0.0; |
97 | const auto &jcp = pd()->jcp_; |
98 | beta_ = jcp.with_sum ? one : zero; |
99 | |
100 | if (jcp.with_eltwise || jcp.with_binary) |
101 | CHECK(safe_ptr_assign(post_ops_, new ref_post_ops_t(jcp.post_ops))); |
102 | return status::success; |
103 | } |
104 | |
105 | typedef typename prec_traits<data_type::f32>::type data_t; |
106 | |
107 | status_t execute(const exec_ctx_t &ctx) const override { |
108 | bool is_nspc = pd()->jcp_.is_nspc; |
109 | return is_nspc ? execute_forward_nspc(ctx) : execute_forward_ncsp(ctx); |
110 | } |
111 | |
112 | private: |
113 | status_t execute_forward_ncsp(const exec_ctx_t &ctx) const; |
114 | status_t execute_forward_nspc(const exec_ctx_t &ctx) const; |
115 | status_t execute_forward_thr_nspc(const exec_ctx_t &ctx, const int ithr, |
116 | const int nthr, const data_t *src_base, const data_t *wei_base, |
117 | const data_t *bia_base, data_t *dst_base, |
118 | const memory_tracking::grantor_t &scratchpad) const; |
119 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
120 | |
121 | data_t beta_; |
122 | |
123 | std::unique_ptr<ref_post_ops_t> post_ops_; |
124 | }; |
125 | |
126 | struct gemm_convolution_bwd_data_t : public primitive_t { |
127 | struct pd_t : public cpu_convolution_bwd_data_pd_t { |
128 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
129 | const convolution_fwd_pd_t *hint_fwd_pd) |
130 | : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd), jcp_() {} |
131 | |
132 | DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_data_t, |
133 | USE_GLOBAL_SCRATCHPAD); |
134 | |
135 | status_t init(engine_t *engine) { |
136 | bool ok = true && desc()->prop_kind == prop_kind::backward_data |
137 | && set_default_alg_kind(alg_kind::convolution_direct) |
138 | && expect_data_types(data_type::f32, data_type::f32, |
139 | data_type::undef, data_type::f32, data_type::f32) |
140 | && !has_zero_dim_memory() && attr()->has_default_values(); |
141 | if (!ok) return status::unimplemented; |
142 | |
143 | auto scratchpad = scratchpad_registry().registrar(); |
144 | return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, |
145 | *desc(), diff_src_md_, weights_md_, diff_dst_md_, bias_md_, |
146 | attr_, dnnl_get_max_threads()); |
147 | } |
148 | |
149 | conv_gemm_conf_t jcp_; |
150 | }; |
151 | |
152 | gemm_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} |
153 | |
154 | typedef typename prec_traits<data_type::f32>::type data_t; |
155 | |
156 | status_t execute(const exec_ctx_t &ctx) const override { |
157 | bool is_nspc = pd()->jcp_.is_nspc; |
158 | return is_nspc ? execute_backward_data_nspc(ctx) |
159 | : execute_backward_data_ncsp(ctx); |
160 | } |
161 | |
162 | private: |
163 | status_t execute_backward_data_nspc(const exec_ctx_t &ctx) const; |
164 | status_t execute_backward_data_ncsp(const exec_ctx_t &ctx) const; |
165 | status_t execute_backward_data_thr_nspc(const int ithr, const int nthr, |
166 | const data_t *diff_dst_base, const data_t *wei_base, |
167 | const data_t *bia_base, data_t *diff_src_base, |
168 | const memory_tracking::grantor_t &scratchpad) const; |
169 | |
170 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
171 | }; |
172 | |
173 | struct gemm_convolution_bwd_weights_t : public primitive_t { |
174 | struct pd_t : public cpu_convolution_bwd_weights_pd_t { |
175 | pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, |
176 | const convolution_fwd_pd_t *hint_fwd_pd) |
177 | : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) |
178 | , jcp_() {} |
179 | |
180 | DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_weights_t, |
181 | USE_GLOBAL_SCRATCHPAD); |
182 | |
183 | status_t init(engine_t *engine) { |
184 | bool ok = true && desc()->prop_kind == prop_kind::backward_weights |
185 | && set_default_alg_kind(alg_kind::convolution_direct) |
186 | && expect_data_types(data_type::f32, data_type::f32, |
187 | data_type::f32, data_type::f32, data_type::f32) |
188 | && !has_zero_dim_memory() && attr()->has_default_values(); |
189 | if (!ok) return status::unimplemented; |
190 | |
191 | auto scratchpad = scratchpad_registry().registrar(); |
192 | return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad, |
193 | *desc(), src_md_, diff_weights_md_, diff_dst_md_, |
194 | diff_bias_md_, attr_, dnnl_get_max_threads()); |
195 | } |
196 | |
197 | conv_gemm_conf_t jcp_; |
198 | }; |
199 | |
200 | gemm_convolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {} |
201 | |
202 | typedef typename prec_traits<data_type::f32>::type data_t; |
203 | |
204 | status_t execute(const exec_ctx_t &ctx) const override { |
205 | const bool is_nspc = pd()->jcp_.is_nspc; |
206 | return is_nspc ? execute_backward_weights_nspc(ctx) |
207 | : execute_backward_weights_ncsp(ctx); |
208 | } |
209 | |
210 | private: |
211 | status_t execute_backward_weights_ncsp(const exec_ctx_t &ctx) const; |
212 | status_t execute_backward_weights_nspc(const exec_ctx_t &ctx) const; |
213 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
214 | }; |
215 | |
216 | } // namespace cpu |
217 | } // namespace impl |
218 | } // namespace dnnl |
219 | |
220 | #endif |
221 | |