1/*******************************************************************************
2* Copyright 2017-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_CORE_F32_WINO_CONV_4X3_HPP
18#define CPU_X64_JIT_AVX512_CORE_F32_WINO_CONV_4X3_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/memory_tracking.hpp"
22#include "common/primitive.hpp"
23
24#include "cpu/cpu_convolution_pd.hpp"
25#include "cpu/platform.hpp"
26
27#include "cpu/x64/jit_avx512_core_f32_wino_conv_4x3_kernel.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace cpu {
32namespace x64 {
33
34namespace winograd_avx512_core {
35inline void init_scratchpad(memory_tracking::registrar_t &scratchpad,
36 const jit_conv_winograd_conf_t &jcp) {
37 using namespace utils;
38 using namespace memory_tracking::names;
39
40 size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc;
41 size_t V_sz
42 = (size_t)alpha * alpha * jcp.mb * jcp.ic * jcp.itiles * jcp.jtiles;
43 size_t M_sz
44 = (size_t)alpha * alpha * jcp.mb * jcp.oc * jcp.itiles * jcp.jtiles;
45
46 switch (jcp.sched_policy) {
47 case WSCHED_DATA_W_SGD:
48 V_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur
49 * jcp.tile_block_ur * jcp.ic;
50 M_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur
51 * jcp.tile_block_ur * jcp.oc;
52 break;
53 case WSCHED_WEI_SDGtWo:
54 U_sz = (size_t)jcp.nthr
55 * (alpha * alpha * jcp.oc * (jcp.ic / jcp.nb_ic)
56 + jcp.ic * jcp.oc * jcp.kh * jcp.kw);
57 M_sz = (size_t)jcp.nthr * alpha * alpha
58 * (jcp.ntiles / jcp.tile_block) * (jcp.oc / jcp.nb_oc);
59 V_sz = (size_t)jcp.nthr * alpha * alpha
60 * (jcp.ntiles / jcp.tile_block) * (jcp.ic / jcp.nb_ic);
61 break;
62 case WSCHED_WEI_S_D_Giot_W:
63 U_sz = (jcp.nthr + static_cast<size_t>(1)) * alpha * alpha * jcp.ic
64 * jcp.oc;
65 M_sz = (size_t)alpha * alpha * jcp.oc * jcp.ntiles;
66 V_sz = (size_t)alpha * alpha * jcp.ic * jcp.ntiles;
67 break;
68 default: break;
69 }
70
71 scratchpad.book<float>(key_wino_U, U_sz, PAGE_2M);
72 scratchpad.book<float>(key_wino_V, V_sz, PAGE_2M);
73 scratchpad.book<float>(key_wino_M, M_sz, PAGE_2M);
74
75 if (one_of(jcp.sched_policy, WSCHED_WEI_SDGtWo, WSCHED_WEI_S_D_Giot_W)) {
76 size_t br_sz = (size_t)jcp.nthr * jcp.oc;
77 scratchpad.book<float>(key_conv_bia_reduction, br_sz, PAGE_2M);
78 }
79}
80} // namespace winograd_avx512_core
81
82template <bool is_fwd>
83struct _jit_avx512_core_f32_wino_conv_4x3_t {
84
85 _jit_avx512_core_f32_wino_conv_4x3_t(
86 const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr)
87 : attr_(attr) {}
88
89protected:
90 void weight_transform_data(
91 const jit_conv_winograd_conf_t &jcp, float *wp, float *twp) const;
92 void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
93 float *inp, float *tinp) const;
94 void input_transform_tileblock_data(int tile_block,
95 const jit_conv_winograd_conf_t &jcp, float *inp, float *tinp) const;
96 void output_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
97 const post_ops_t &p_ops, float *toutp, float *pout_b,
98 float *bias) const;
99 void output_transform_tileblock_data(int tile_block,
100 const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
101 float *toutp, float *outp, float *bias) const;
102 void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr, float *wei_ptr,
103 float *bias_ptr,
104 const memory_tracking::grantor_t &scratchpad) const;
105 void _execute_data_W_SGD(float *inp_ptr, float *out_ptr, float *wei_ptr,
106 float *bias_ptr,
107 const memory_tracking::grantor_t &scratchpad) const;
108 std::unique_ptr<_jit_avx512_core_f32_wino_conv_4x3_data_kernel> kernel_;
109 const primitive_attr_t *attr_;
110
111private:
112 DNNL_DISALLOW_COPY_AND_ASSIGN(_jit_avx512_core_f32_wino_conv_4x3_t);
113};
114
115struct jit_avx512_core_f32_wino_conv_4x3_fwd_t
116 : _jit_avx512_core_f32_wino_conv_4x3_t<true>,
117 public primitive_t {
118 struct pd_t : public cpu_convolution_fwd_pd_t {
119 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
120 const typename pd_t::base_class *hint_fwd_pd)
121 : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}
122
123 DECLARE_COMMON_PD_T(
124 JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""),
125 jit_avx512_core_f32_wino_conv_4x3_fwd_t, USE_GLOBAL_SCRATCHPAD);
126
127 status_t init(engine_t *engine) {
128 using namespace data_type;
129 bool ok = is_fwd()
130 && utils::one_of(desc()->alg_kind,
131 alg_kind::convolution_auto,
132 alg_kind::convolution_winograd)
133 && expect_data_types(f32, f32, f32, f32, f32)
134 && attr()->has_default_values(
135 primitive_attr_t::skip_mask_t::post_ops, f32)
136 && set_default_formats()
137 && attr_.set_default_formats(dst_md(0)) == status::success;
138 if (!ok) return status::unimplemented;
139
140 CHECK(jit_avx512_core_f32_wino_conv_4x3_fwd_kernel::init_conf(
141 jcp_, *desc(), src_md_, weights_md_, dst_md_, *attr()));
142 set_default_alg_kind(alg_kind::convolution_winograd);
143
144 auto scratchpad = scratchpad_registry().registrar();
145 winograd_avx512_core::init_scratchpad(scratchpad, jcp_);
146
147 return status::success;
148 }
149
150 jit_conv_winograd_conf_t jcp_;
151
152 protected:
153 bool set_default_formats() {
154 using namespace format_tag;
155 auto wei_fmt = desc()->prop_kind == prop_kind::forward_training
156 ? (with_groups() ? gOIhw16i16o : OIhw16i16o)
157 : any;
158 return set_default_formats_common(nChw16c, wei_fmt, nChw16c);
159 }
160 };
161
162 jit_avx512_core_f32_wino_conv_4x3_fwd_t(const pd_t *apd)
163 : _jit_avx512_core_f32_wino_conv_4x3_t<true>(apd->jcp_, apd->attr())
164 , primitive_t(apd) {}
165 ~jit_avx512_core_f32_wino_conv_4x3_fwd_t() = default;
166
167 typedef typename prec_traits<data_type::f32>::type data_t;
168
169 status_t init(engine_t *engine) override {
170 CHECK(safe_ptr_assign(kernel_,
171 new _jit_avx512_core_f32_wino_conv_4x3_data_kernel(
172 pd()->jcp_)));
173 return kernel_->create_kernel();
174 }
175
176 status_t execute(const exec_ctx_t &ctx) const override {
177 auto src = CTX_IN_MEM(const float *, DNNL_ARG_SRC);
178 auto weights = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS);
179 auto bias = CTX_IN_MEM(const float *, DNNL_ARG_BIAS);
180 auto dst = CTX_OUT_MEM(float *, DNNL_ARG_DST);
181
182 auto scratchpad = ctx.get_scratchpad_grantor();
183
184 switch ((pd()->jcp_).sched_policy) {
185 case WSCHED_DATA_W_S_G_D:
186 this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights,
187 (float *)bias, scratchpad);
188 break;
189 case WSCHED_DATA_W_SGD:
190 this->_execute_data_W_SGD((float *)src, dst, (float *)weights,
191 (float *)bias, scratchpad);
192 break;
193 default: break;
194 }
195 return status::success;
196 }
197
198private:
199 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
200};
201
202struct jit_avx512_core_f32_wino_conv_4x3_bwd_data_t
203 : _jit_avx512_core_f32_wino_conv_4x3_t<false>,
204 public primitive_t {
205 struct pd_t : public cpu_convolution_bwd_data_pd_t {
206 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
207 const convolution_fwd_pd_t *hint_fwd_pd)
208 : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}
209
210 DECLARE_COMMON_PD_T(
211 JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""),
212 jit_avx512_core_f32_wino_conv_4x3_bwd_data_t,
213 USE_GLOBAL_SCRATCHPAD);
214
215 status_t init(engine_t *engine) {
216 bool ok = true && dnnl_thr_syncable()
217 && desc()->prop_kind == prop_kind::backward_data
218 && utils::one_of(desc()->alg_kind,
219 alg_kind::convolution_auto,
220 alg_kind::convolution_winograd)
221 && expect_data_types(data_type::f32, data_type::f32,
222 data_type::undef, data_type::f32, data_type::f32)
223 && attr()->has_default_values() && set_default_formats();
224 if (!ok) return status::unimplemented;
225
226 status_t status
227 = jit_avx512_core_f32_wino_conv_4x3_bwd_data_kernel ::
228 init_conf(jcp_, *desc(), *diff_src_md(),
229 *weights_md(), *diff_dst_md());
230 if (status != status::success) return status;
231 set_default_alg_kind(alg_kind::convolution_winograd);
232
233 auto scratchpad = scratchpad_registry().registrar();
234 winograd_avx512_core::init_scratchpad(scratchpad, jcp_);
235
236 return status;
237 }
238
239 jit_conv_winograd_conf_t jcp_;
240
241 protected:
242 bool set_default_formats() {
243 using namespace format_tag;
244 auto wei_fmt = with_groups() ? gOIhw16i16o : OIhw16i16o;
245 return set_default_formats_common(nChw16c, wei_fmt, nChw16c);
246 }
247 };
248
249 jit_avx512_core_f32_wino_conv_4x3_bwd_data_t(const pd_t *apd)
250 : _jit_avx512_core_f32_wino_conv_4x3_t<false>(apd->jcp_, apd->attr())
251 , primitive_t(apd) {}
252
253 typedef typename prec_traits<data_type::f32>::type data_t;
254
255 status_t init(engine_t *engine) override {
256 CHECK(safe_ptr_assign(kernel_,
257 new _jit_avx512_core_f32_wino_conv_4x3_data_kernel(
258 pd()->jcp_)));
259 return kernel_->create_kernel();
260 }
261
262 status_t execute(const exec_ctx_t &ctx) const override {
263 auto diff_dst = CTX_IN_MEM(const float *, DNNL_ARG_DIFF_DST);
264 auto weights = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS);
265 auto diff_src = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_SRC);
266
267 auto scratchpad = ctx.get_scratchpad_grantor();
268
269 switch ((pd()->jcp_).sched_policy) {
270 case WSCHED_DATA_W_S_G_D:
271 this->_execute_data_W_S_G_D((float *)diff_dst, diff_src,
272 (float *)weights, NULL, scratchpad);
273 break;
274
275 case WSCHED_DATA_W_SGD:
276 this->_execute_data_W_SGD((float *)diff_dst, diff_src,
277 (float *)weights, NULL, scratchpad);
278 break;
279
280 default: break;
281 }
282
283 return status::success;
284 }
285
286private:
287 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
288};
289
290struct jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t : public primitive_t {
291 struct pd_t : public cpu_convolution_bwd_weights_pd_t {
292 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
293 const convolution_fwd_pd_t *hint_fwd_pd)
294 : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd)
295 , jcp_() {}
296
297 DECLARE_COMMON_PD_T(
298 JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""),
299 jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t,
300 USE_GLOBAL_SCRATCHPAD);
301
302 status_t init(engine_t *engine) {
303 bool ok = true && dnnl_thr_syncable()
304 && desc()->prop_kind == prop_kind::backward_weights
305 && utils::one_of(desc()->alg_kind,
306 alg_kind::convolution_auto,
307 alg_kind::convolution_winograd)
308 && expect_data_types(data_type::f32, data_type::f32,
309 data_type::f32, data_type::f32, data_type::f32)
310 && attr()->has_default_values() && set_default_formats();
311 if (!ok) return status::unimplemented;
312
313 status_t status
314 = jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel::
315 init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(),
316 *diff_weights_md());
317 if (status != status::success) return status;
318 set_default_alg_kind(alg_kind::convolution_winograd);
319
320 auto scratchpad = scratchpad_registry().registrar();
321 winograd_avx512_core::init_scratchpad(scratchpad, jcp_);
322
323 return status;
324 }
325
326 jit_conv_winograd_conf_t jcp_;
327
328 protected:
329 bool set_default_formats() {
330 using namespace format_tag;
331 auto wei_fmt = with_groups() ? gOIhw16i16o : OIhw16i16o;
332 return set_default_formats_common(nChw16c, wei_fmt, nChw16c);
333 }
334 };
335
336 jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t(const pd_t *apd)
337 : primitive_t(apd) {}
338
339 typedef typename prec_traits<data_type::f32>::type data_t;
340
341 status_t init(engine_t *engine) override {
342 CHECK(safe_ptr_assign(kernel_,
343 new jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel(
344 pd()->jcp_)));
345 return kernel_->create_kernel();
346 }
347
348 status_t execute(const exec_ctx_t &ctx) const override {
349 auto diff_dst = CTX_IN_MEM(const float *, DNNL_ARG_DIFF_DST);
350 auto src = CTX_IN_MEM(const float *, DNNL_ARG_SRC);
351 auto diff_weights = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_WEIGHTS);
352 auto diff_bias = CTX_OUT_MEM(float *, DNNL_ARG_DIFF_BIAS);
353
354 switch (kernel_->jcp.sched_policy) {
355 case WSCHED_WEI_SDGtWo:
356 _execute_backward_weights_SDGtWo(src, diff_dst, diff_weights,
357 diff_bias, ctx.get_scratchpad_grantor());
358 break;
359 case WSCHED_WEI_S_D_Giot_W:
360 _execute_backward_weights_S_D_Giot_W(src, diff_dst,
361 diff_weights, diff_bias, ctx.get_scratchpad_grantor());
362 break;
363 default: assert(kernel_->jcp.sched_policy != WSCHED_INVALID); break;
364 }
365 return status::success;
366 }
367
368private:
369 void _execute_backward_weights_SDGtWo(const float *src,
370 const float *diff_dst, float *diff_weights, float *diff_bias,
371 const memory_tracking::grantor_t &scratchpad) const;
372 void _execute_backward_weights_S_D_Giot_W(const float *src,
373 const float *diff_dst, float *diff_weights, float *diff_bias,
374 const memory_tracking::grantor_t &scratchpad) const;
375
376 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
377 std::unique_ptr<jit_avx512_core_f32_wino_conv_4x3_bwd_weights_kernel>
378 kernel_;
379};
380
381} // namespace x64
382} // namespace cpu
383} // namespace impl
384} // namespace dnnl
385
386#endif
387
388// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
389