1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX2_CONVOLUTION_HPP
18#define CPU_X64_JIT_AVX2_CONVOLUTION_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/memory_tracking.hpp"
23#include "common/primitive.hpp"
24#include "common/utils.hpp"
25
26#include "cpu/cpu_convolution_pd.hpp"
27#include "cpu/x64/cpu_reducer.hpp"
28
29#include "cpu/x64/jit_avx2_conv_kernel_f32.hpp"
30
31namespace dnnl {
32namespace impl {
33namespace cpu {
34namespace x64 {
35
36struct jit_avx2_convolution_fwd_t : public primitive_t {
37 struct pd_t : public cpu_convolution_fwd_pd_t {
38 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
39 const typename pd_t::base_class *hint_fwd_pd)
40 : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}
41
42 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", jcp_.isa, ""),
43 jit_avx2_convolution_fwd_t);
44
45 status_t init(engine_t *engine) {
46 bool ok = true && is_fwd()
47 && set_default_alg_kind(alg_kind::convolution_direct)
48 && expect_data_types(data_type::f32, data_type::f32,
49 data_type::f32, data_type::f32, data_type::f32)
50 && attr()->has_default_values(
51 primitive_attr_t::skip_mask_t::post_ops,
52 data_type::f32)
53 && !has_zero_dim_memory() && set_default_formats()
54 && attr_.set_default_formats(dst_md(0)) == status::success;
55 if (!ok) return status::unimplemented;
56
57 CHECK(jit_avx2_conv_fwd_kernel_f32::init_conf(
58 jcp_, *desc(), src_md(), weights_md(), dst_md(), *attr()));
59
60 auto scratchpad = scratchpad_registry().registrar();
61 jit_avx2_conv_fwd_kernel_f32::init_scratchpad(scratchpad, jcp_);
62
63 return status::success;
64 }
65
66 jit_conv_conf_t jcp_;
67
68 protected:
69 bool set_default_formats() {
70 using namespace format_tag;
71
72 const memory_desc_wrapper src_d(&src_md_);
73 const memory_desc_wrapper dst_d(&dst_md_);
74
75 const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
76 const auto dat_tag_ncx = utils::pick(ndims() - 3, ncw, nchw, ncdhw);
77 const auto dat_tag_nCx8c
78 = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
79 const auto curr_src_tag = src_d.matches_one_of_tag(
80 dat_tag_nxc, dat_tag_ncx, dat_tag_nCx8c);
81 const auto curr_dst_tag = dst_d.matches_one_of_tag(
82 dat_tag_nxc, dat_tag_ncx, dat_tag_nCx8c);
83 const auto is_data_layout_nxc
84 = IMPLICATION(curr_src_tag != dat_tag_nxc,
85 src_d.format_kind() == format_kind::any)
86 && IMPLICATION(curr_dst_tag != dat_tag_nxc,
87 dst_d.format_kind() == format_kind::any)
88 && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
89
90 const bool flat = IC() < 8;
91 auto src_tag = is_data_layout_nxc
92 ? dat_tag_nxc
93 : flat ? dat_tag_ncx : dat_tag_nCx8c;
94 auto dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
95 auto wei_tag = with_groups()
96 ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o,
97 gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
98 : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o,
99 OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o);
100
101 return set_default_formats_common(src_tag, wei_tag, dst_tag);
102 }
103 };
104
105 jit_avx2_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}
106
107 typedef typename prec_traits<data_type::f32>::type data_t;
108
109 status_t init(engine_t *engine) override {
110 CHECK(safe_ptr_assign(kernel_,
111 new jit_avx2_conv_fwd_kernel_f32(
112 pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
113 return kernel_->create_kernel();
114 }
115
116 status_t execute(const exec_ctx_t &ctx) const override {
117 execute_forward(ctx);
118 return status::success;
119 }
120
121private:
122 void execute_forward(const exec_ctx_t &ctx) const;
123 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
124
125 std::unique_ptr<jit_avx2_conv_fwd_kernel_f32> kernel_;
126};
127
128struct jit_avx2_convolution_bwd_data_t : public primitive_t {
129 struct pd_t : public cpu_convolution_bwd_data_pd_t {
130 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
131 const convolution_fwd_pd_t *hint_fwd_pd)
132 : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}
133
134 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
135 jit_avx2_convolution_bwd_data_t);
136
137 status_t init(engine_t *engine) {
138 bool ok = true && desc()->prop_kind == prop_kind::backward_data
139 && set_default_alg_kind(alg_kind::convolution_direct)
140 && expect_data_types(data_type::f32, data_type::f32,
141 data_type::undef, data_type::f32, data_type::f32)
142 && attr()->has_default_values() && !has_zero_dim_memory()
143 && set_default_formats();
144 if (!ok) return status::unimplemented;
145
146 status_t status = jit_avx2_conv_bwd_data_kernel_f32::init_conf(jcp_,
147 *desc(), *diff_src_md(), *weights_md(), *diff_dst_md());
148 if (status != status::success) return status;
149
150 auto scratchpad = scratchpad_registry().registrar();
151 jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(
152 scratchpad, jcp_);
153
154 return status::success;
155 }
156
157 jit_conv_conf_t jcp_;
158
159 protected:
160 bool set_default_formats() {
161 using namespace format_tag;
162
163 const memory_desc_wrapper diff_src_d(&diff_src_md_);
164 const memory_desc_wrapper diff_dst_d(&diff_dst_md_);
165
166 const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
167 const auto dat_tag_nCx8c
168 = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
169 const auto curr_src_tag
170 = diff_src_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
171 const auto curr_dst_tag
172 = diff_dst_d.matches_one_of_tag(dat_tag_nxc, dat_tag_nCx8c);
173 const auto is_data_layout_nxc
174 = IMPLICATION(curr_src_tag != dat_tag_nxc,
175 diff_src_d.format_kind() == format_kind::any)
176 && IMPLICATION(curr_dst_tag != dat_tag_nxc,
177 diff_dst_d.format_kind() == format_kind::any)
178 && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
179
180 auto dat_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
181 auto wei_tag = with_groups()
182 ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i)
183 : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i);
184
185 return set_default_formats_common(dat_tag, wei_tag, dat_tag);
186 }
187 };
188
189 jit_avx2_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}
190
191 typedef typename prec_traits<data_type::f32>::type data_t;
192
193 status_t init(engine_t *engine) override {
194 CHECK(safe_ptr_assign(
195 kernel_, new jit_avx2_conv_bwd_data_kernel_f32(pd()->jcp_)));
196 return kernel_->create_kernel();
197 }
198
199 status_t execute(const exec_ctx_t &ctx) const override {
200 execute_backward_data(ctx);
201 return status::success;
202 }
203
204private:
205 void execute_backward_data(const exec_ctx_t &ctx) const;
206 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
207
208 std::unique_ptr<jit_avx2_conv_bwd_data_kernel_f32> kernel_;
209};
210
211struct jit_avx2_convolution_bwd_weights_t : public primitive_t {
212 struct pd_t : public cpu_convolution_bwd_weights_pd_t {
213 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
214 const convolution_fwd_pd_t *hint_fwd_pd)
215 : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd)
216 , jcp_() {}
217
218 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
219 jit_avx2_convolution_bwd_weights_t);
220
221 status_t init(engine_t *engine) {
222 bool ok = true && desc()->prop_kind == prop_kind::backward_weights
223 && set_default_alg_kind(alg_kind::convolution_direct)
224 && expect_data_types(data_type::f32, data_type::f32,
225 data_type::f32, data_type::f32, data_type::f32)
226 && attr()->has_default_values() && !has_zero_dim_memory()
227 && set_default_formats();
228 if (!ok) return status::unimplemented;
229
230 status_t status = jit_avx2_conv_bwd_weights_kernel_f32::init_conf(
231 jcp_, *desc(), *src_md(), *diff_weights_md(),
232 *diff_dst_md());
233 if (status != status::success) return status;
234
235 init_balancers();
236
237 auto scratchpad = scratchpad_registry().registrar();
238 jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(
239 scratchpad, jcp_);
240
241 auto reducer_bia_scratchpad = memory_tracking::registrar_t(
242 scratchpad, memory_tracking::names::prefix_reducer_bia);
243 reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
244
245 auto reducer_wei_scratchpad = memory_tracking::registrar_t(
246 scratchpad, memory_tracking::names::prefix_reducer_wei);
247 reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad);
248
249 return status::success;
250 }
251
252 jit_conv_conf_t jcp_;
253 cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
254 cpu_reducer_t<data_type::f32>::conf_t reducer_wei_conf_;
255
256 protected:
257 bool set_default_formats() {
258 using namespace format_tag;
259 const bool flat = IC() == 3;
260
261 const memory_desc_wrapper src_d(&src_md_);
262 const memory_desc_wrapper diff_dst_d(&diff_dst_md_);
263
264 const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
265 const auto dat_tag_ncx = utils::pick(ndims() - 3, ncw, nchw, ncdhw);
266 const auto dat_tag_nCx8c
267 = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
268 const auto curr_src_tag = src_d.matches_one_of_tag(
269 dat_tag_nxc, dat_tag_ncx, dat_tag_nCx8c);
270 const auto curr_dst_tag = diff_dst_d.matches_one_of_tag(
271 dat_tag_nxc, dat_tag_ncx, dat_tag_nCx8c);
272 const auto is_data_layout_nxc
273 = IMPLICATION(curr_src_tag != dat_tag_nxc,
274 src_d.format_kind() == format_kind::any)
275 && IMPLICATION(curr_dst_tag != dat_tag_nxc,
276 diff_dst_d.format_kind() == format_kind::any)
277 && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
278
279 auto src_tag = is_data_layout_nxc
280 ? dat_tag_nxc
281 : flat ? dat_tag_ncx : dat_tag_nCx8c;
282 auto dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
283 auto wei_tag = with_groups()
284 ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o,
285 gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
286 : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o,
287 OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o);
288
289 return set_default_formats_common(src_tag, wei_tag, dst_tag);
290 }
291
292 private:
293 void init_balancers() {
294 const int max_threads = dnnl_get_max_threads();
295 const size_t max_buffer_size = 1 << 21; /* just a heuristic */
296
297 if (with_bias()) {
298 reducer_bia_conf_.init(reduce_balancer_t(max_threads,
299 jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb,
300 max_buffer_size, true));
301 }
302
303 reducer_wei_conf_.init(reduce_balancer_t(max_threads,
304 jcp_.kd * jcp_.kh * jcp_.kw * jcp_.ic_block * jcp_.oc_block,
305 jcp_.ngroups * jcp_.nb_ic * jcp_.nb_oc, jcp_.mb * jcp_.od,
306 max_buffer_size, true));
307 }
308 };
309
310 jit_avx2_convolution_bwd_weights_t(const pd_t *apd) : primitive_t(apd) {}
311
312 typedef typename prec_traits<data_type::f32>::type data_t;
313
314 status_t init(engine_t *engine) override {
315 CHECK(safe_ptr_assign(
316 kernel_, new jit_avx2_conv_bwd_weights_kernel_f32(pd()->jcp_)));
317 CHECK(safe_ptr_assign(reducer_bias_,
318 new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_)));
319 CHECK(safe_ptr_assign(reducer_weights_,
320 new cpu_reducer_t<data_type::f32>(pd()->reducer_wei_conf_)));
321 CHECK(kernel_->create_kernel());
322 CHECK(reducer_weights_->create_kernel());
323 CHECK(reducer_bias_->create_kernel());
324 return status::success;
325 }
326
327 status_t execute(const exec_ctx_t &ctx) const override {
328 execute_backward_weights(ctx);
329 return status::success;
330 }
331
332private:
333 void execute_backward_weights(const exec_ctx_t &ctx) const;
334 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
335
336 std::unique_ptr<jit_avx2_conv_bwd_weights_kernel_f32> kernel_;
337 std::unique_ptr<cpu_reducer_t<data_type::f32>> reducer_weights_,
338 reducer_bias_;
339};
340
341} // namespace x64
342} // namespace cpu
343} // namespace impl
344} // namespace dnnl
345
346#endif
347
348// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
349