/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_GEMM_CONVOLUTION_HPP
#define CPU_GEMM_CONVOLUTION_HPP

#include "common/broadcast_strategy.hpp"
#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"

#include "cpu/cpu_convolution_pd.hpp"
#include "cpu/gemm/gemm.hpp"
#include "cpu/gemm_convolution_utils.hpp"
#include "cpu/primitive_attr_postops.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
struct gemm_convolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}

        DECLARE_COMMON_PD_T(
                GEMM_IMPL_STR, gemm_convolution_fwd_t, USE_GLOBAL_SCRATCHPAD);

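        // Applicability checks: f32 direct convolution on non-empty tensors
        // with only post-op attributes set. On success, init_conf() picks
        // the blocking/threading scheme and registers the scratchpad
        // buffers this primitive needs.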
        status_t init(engine_t *engine) {
            using namespace data_type;

            bool ok = is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(f32, f32, f32, f32, f32)
                    && !has_zero_dim_memory()
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops, f32)
                    && post_ops_ok();
            if (!ok) return status::unimplemented;

            auto scratchpad = scratchpad_registry().registrar();
            return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
                    *desc(), src_md_, weights_md_, dst_md_, bias_md_, attr_,
                    dnnl_get_max_threads());
        }

        conv_gemm_conf_t jcp_;

    protected:
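        // Accept only post-op chains this path can honor: sum, eltwise,
        // and binary entries, with sum restricted to the first position
        // (it is folded into the GEMM beta) and binary restricted to
        // scalar or per-channel (per_oc) broadcasts of the right-hand side.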
        bool post_ops_ok() const {
            auto const &po = attr()->post_ops_;
            auto is_eltwise
                    = [&](int idx) { return po.entry_[idx].is_eltwise(); };
            auto is_sum = [&](int idx) { return po.entry_[idx].is_sum(); };
            auto is_binary
                    = [&](int idx) { return po.entry_[idx].is_binary(); };

            for (int idx = 0; idx < po.len(); idx++) {
                bool ok = utils::one_of(true, is_sum(idx), is_binary(idx),
                                  is_eltwise(idx))
                        && IMPLICATION(is_sum(idx), idx == 0)
                        && IMPLICATION(is_binary(idx),
                                dnnl::impl::get_rhs_arg_broadcasting_strategy(
                                        po.entry_[idx].binary.src1_desc,
                                        dst_md_,
                                        {broadcasting_strategy_t::scalar,
                                                broadcasting_strategy_t::
                                                        per_oc})
                                        != broadcasting_strategy_t::
                                                unsupported);
                if (!ok) return false;
            }

            return true;
        }
    };

    gemm_convolution_fwd_t(const pd_t *apd)
        : primitive_t(apd), post_ops_(nullptr) {}

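    // With a sum post-op the GEMM accumulates into dst (beta = 1);
    // otherwise dst is overwritten (beta = 0). Remaining eltwise/binary
    // post-ops are applied by a reference post-op handler after the GEMM.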
    status_t init(engine_t *engine) override {
        const data_t one = 1.0, zero = 0.0;
        const auto &jcp = pd()->jcp_;
        beta_ = jcp.with_sum ? one : zero;

        if (jcp.with_eltwise || jcp.with_binary)
            CHECK(safe_ptr_assign(post_ops_, new ref_post_ops_t(jcp.post_ops)));
        return status::success;
    }

    typedef typename prec_traits<data_type::f32>::type data_t;

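    // jcp_.is_nspc is fixed at pd creation: channels-last (nspc)
    // activations take a dedicated path; plain (ncsp) activations take
    // the other.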
    status_t execute(const exec_ctx_t &ctx) const override {
        const bool is_nspc = pd()->jcp_.is_nspc;
        return is_nspc ? execute_forward_nspc(ctx) : execute_forward_ncsp(ctx);
    }

private:
    status_t execute_forward_ncsp(const exec_ctx_t &ctx) const;
    status_t execute_forward_nspc(const exec_ctx_t &ctx) const;
    status_t execute_forward_thr_nspc(const exec_ctx_t &ctx, const int ithr,
            const int nthr, const data_t *src_base, const data_t *wei_base,
            const data_t *bia_base, data_t *dst_base,
            const memory_tracking::grantor_t &scratchpad) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    data_t beta_;

    std::unique_ptr<ref_post_ops_t> post_ops_;
};

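// GEMM-based f32 convolution (backward data). The weights are multiplied
// with diff_dst into a column buffer, which col2im scatters back into
// diff_src.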
struct gemm_convolution_bwd_data_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_data_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}

        DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_data_t,
                USE_GLOBAL_SCRATCHPAD);

        status_t init(engine_t *engine) {
            bool ok = desc()->prop_kind == prop_kind::backward_data
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(data_type::f32, data_type::f32,
                            data_type::undef, data_type::f32, data_type::f32)
                    && !has_zero_dim_memory() && attr()->has_default_values();
            if (!ok) return status::unimplemented;

            auto scratchpad = scratchpad_registry().registrar();
            return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
                    *desc(), diff_src_md_, weights_md_, diff_dst_md_, bias_md_,
                    attr_, dnnl_get_max_threads());
        }

        conv_gemm_conf_t jcp_;
    };

    gemm_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}

    typedef typename prec_traits<data_type::f32>::type data_t;

    status_t execute(const exec_ctx_t &ctx) const override {
        const bool is_nspc = pd()->jcp_.is_nspc;
        return is_nspc ? execute_backward_data_nspc(ctx)
                       : execute_backward_data_ncsp(ctx);
    }

private:
    status_t execute_backward_data_nspc(const exec_ctx_t &ctx) const;
    status_t execute_backward_data_ncsp(const exec_ctx_t &ctx) const;
    status_t execute_backward_data_thr_nspc(const int ithr, const int nthr,
            const data_t *diff_dst_base, const data_t *wei_base,
            const data_t *bia_base, data_t *diff_src_base,
            const memory_tracking::grantor_t &scratchpad) const;

    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};

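// GEMM-based f32 convolution (backward weights). im2col of the source
// followed by a GEMM with diff_dst accumulates the weight gradients
// (and the bias gradient, when present) across the minibatch.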
struct gemm_convolution_bwd_weights_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_weights_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_() {}

        DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_weights_t,
                USE_GLOBAL_SCRATCHPAD);

        status_t init(engine_t *engine) {
            bool ok = desc()->prop_kind == prop_kind::backward_weights
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(data_type::f32, data_type::f32,
                            data_type::f32, data_type::f32, data_type::f32)
                    && !has_zero_dim_memory() && attr()->has_default_values();
            if (!ok) return status::unimplemented;

            auto scratchpad = scratchpad_registry().registrar();
            return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
                    *desc(), src_md_, diff_weights_md_, diff_dst_md_,
                    diff_bias_md_, attr_, dnnl_get_max_threads());
        }

        conv_gemm_conf_t jcp_;
    };

    gemm_convolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {}

    typedef typename prec_traits<data_type::f32>::type data_t;

    status_t execute(const exec_ctx_t &ctx) const override {
        const bool is_nspc = pd()->jcp_.is_nspc;
        return is_nspc ? execute_backward_weights_nspc(ctx)
                       : execute_backward_weights_ncsp(ctx);
    }

private:
    status_t execute_backward_weights_ncsp(const exec_ctx_t &ctx) const;
    status_t execute_backward_weights_nspc(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
};

} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif