1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_CORE_F32_WINO_CONV_2X3_HPP
18#define CPU_X64_JIT_AVX512_CORE_F32_WINO_CONV_2X3_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/dnnl_thread.hpp"
24#include "common/primitive.hpp"
25#include "common/type_helpers.hpp"
26#include "common/utils.hpp"
27
28#include "cpu/cpu_convolution_pd.hpp"
29#include "cpu/cpu_primitive.hpp"
30#include "cpu/platform.hpp"
31
32#include "cpu/x64/jit_generator.hpp"
33#include "cpu/x64/jit_primitive_conf.hpp"
34
35namespace dnnl {
36namespace impl {
37namespace cpu {
38namespace x64 {
39
// JIT helper types used by the primitive declared below. They are only
// forward-declared in this header: the primitive stores them through
// std::unique_ptr members, so their full definitions are not needed here.
struct jit_avx512_core_f32_wino_conv_2x3_dst_trans_t;
struct jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t;
struct jit_avx512_core_f32_wino_conv_2x3_src_trans_t;
43
44struct jit_avx512_core_f32_wino_conv_2x3_fwd_t : public primitive_t {
45 struct pd_t : public cpu_convolution_fwd_pd_t {
46 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
47 const typename pd_t::base_class *hint_fwd_pd)
48 : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}
49
50 DECLARE_COMMON_PD_T(
51 JIT_IMPL_NAME_HELPER("jit_fp32_wino_2x3:", avx512_core, ""),
52 jit_avx512_core_f32_wino_conv_2x3_fwd_t);
53
54 status_t init(engine_t *engine) {
55 using namespace data_type;
56 bool ok = desc()->prop_kind == prop_kind::forward_inference
57 && utils::one_of(desc()->alg_kind,
58 alg_kind::convolution_auto,
59 alg_kind::convolution_winograd)
60 && expect_data_types(f32, f32, f32, f32, f32)
61 && attr()->has_default_values(
62 primitive_attr_t::skip_mask_t::post_ops, f32)
63 && set_default_formats()
64 && attr_.set_default_formats(dst_md(0)) == status::success;
65 if (!ok) return status::unimplemented;
66
67 memory_desc_t expect_wei_md = *weights_md();
68 CHECK(jit_conf(expect_wei_md));
69 set_default_alg_kind(alg_kind::convolution_winograd);
70
71 if (weights_md_.format_kind == format_kind::any)
72 weights_md_ = expect_wei_md;
73 if (weights_md_ != expect_wei_md) return status::unimplemented;
74
75 init_scratchpad();
76
77 return status::success;
78 }
79
80 jit_conv_conf_2x3_wino_t jcp_;
81
82 protected:
83 status_t jit_conf(memory_desc_t &expect_wei_md);
84
85 void init_scratchpad() {
86 using namespace memory_tracking::names;
87
88 auto scratchpad = scratchpad_registry().registrar();
89
90 int wino_size_offset = (jcp_.yb / 2) * (jcp_.xb / 2) + jcp_.xb;
91
92 size_t V_sz = (size_t)jcp_.ic * 16 * wino_size_offset * jcp_.nthr;
93 scratchpad.book<float>(key_wino_V, V_sz, PAGE_4K);
94
95 size_t M_sz = (size_t)jcp_.oc * 16 * wino_size_offset * jcp_.nthr;
96 scratchpad.book<float>(key_wino_M, M_sz, PAGE_4K);
97
98 if (wants_padded_bias()) {
99 assert(jcp_.ngroups == 1);
100 scratchpad.book<float>(key_conv_padded_bias, jcp_.oc);
101 }
102 }
103
104 bool set_default_formats() {
105 using namespace format_tag;
106 return set_default_formats_common(nChw16c, any, nChw16c);
107 }
108 };
109
110 jit_avx512_core_f32_wino_conv_2x3_fwd_t(const pd_t *apd);
111 ~jit_avx512_core_f32_wino_conv_2x3_fwd_t();
112
113 status_t init(engine_t *engine) override;
114
115 status_t execute(const exec_ctx_t &ctx) const override {
116 auto src = CTX_IN_MEM(const float *, DNNL_ARG_SRC);
117 auto wei = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS);
118 auto bia = CTX_IN_MEM(const float *, DNNL_ARG_BIAS);
119 auto dst = CTX_OUT_MEM(float *, DNNL_ARG_DST);
120
121 if (pd()->jcp_.small_mb)
122 execute_forward_small_mb(
123 src, wei, bia, dst, ctx.get_scratchpad_grantor());
124 else
125 execute_forward_mbN(
126 src, wei, bia, dst, ctx.get_scratchpad_grantor());
127
128 return status::success;
129 }
130
131private:
132 void execute_forward_small_mb(const float *src, const float *wei,
133 const float *bia, float *dst,
134 const memory_tracking::grantor_t &scratchpad) const;
135 void execute_forward_mbN(const float *src, const float *wei,
136 const float *bia, float *dst,
137 const memory_tracking::grantor_t &scratchpad) const;
138 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
139
140 std::unique_ptr<jit_avx512_core_f32_wino_conv_2x3_fwd_ker_t> kernel_;
141 std::unique_ptr<jit_avx512_core_f32_wino_conv_2x3_src_trans_t> src_trans_;
142 std::unique_ptr<jit_avx512_core_f32_wino_conv_2x3_dst_trans_t> dst_trans_;
143};
144
145} // namespace x64
146} // namespace cpu
147} // namespace impl
148} // namespace dnnl
149
150#endif
151