1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_UNI_POOLING_HPP |
18 | #define CPU_X64_JIT_UNI_POOLING_HPP |
19 | |
20 | #include <assert.h> |
21 | #include <memory> |
22 | |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/dnnl_thread.hpp" |
25 | #include "common/primitive.hpp" |
26 | #include "common/type_helpers.hpp" |
27 | #include "common/utils.hpp" |
28 | |
29 | #include "cpu/cpu_pooling_pd.hpp" |
30 | #include "cpu/x64/jit_uni_pool_kernel.hpp" |
31 | #include "cpu/x64/jit_uni_reorder.hpp" |
32 | |
33 | namespace dnnl { |
34 | namespace impl { |
35 | namespace cpu { |
36 | namespace x64 { |
37 | |
// Forward declarations of helper types whose definitions are not part of
// this header. Only `trans_context_t` is referenced below, where each
// pooling primitive holds one through a std::unique_ptr.
// NOTE(review): the full definitions presumably live in the matching .cpp
// file - confirm before including this header from a new translation unit
// that needs the complete types.
namespace jit_uni_pooling_utils {
struct trans_wrapper_t;
struct trans_context_t;
} // namespace jit_uni_pooling_utils
42 | |
43 | template <cpu_isa_t isa, impl::data_type_t d_type> |
44 | struct jit_uni_pooling_fwd_t : public primitive_t { |
45 | struct pd_t : public cpu_pooling_fwd_pd_t { |
46 | using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; |
47 | |
48 | DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:" , jpp_.isa, "" ), |
49 | jit_uni_pooling_fwd_t); |
50 | |
51 | status_t init(engine_t *engine) { |
52 | using namespace utils; |
53 | |
54 | const bool ok = is_fwd() && !has_zero_dim_memory() |
55 | && everyone_is( |
56 | d_type, src_md()->data_type, dst_md()->data_type) |
57 | && attr()->has_default_values( |
58 | primitive_attr_t::skip_mask_t::post_ops, d_type) |
59 | && !is_dilated() && set_default_params() == status::success; |
60 | if (!ok) return status::unimplemented; |
61 | |
62 | const bool is_training |
63 | = desc_.prop_kind == prop_kind::forward_training; |
64 | if (desc()->alg_kind == alg_kind::pooling_max && is_training) |
65 | init_default_ws(); |
66 | |
67 | auto scratchpad = scratchpad_registry().registrar(); |
68 | |
69 | CHECK(jit_uni_pool_kernel<isa>::init_conf( |
70 | jpp_, scratchpad, attr_, this)); |
71 | |
72 | return status::success; |
73 | } |
74 | |
75 | jit_pool_conf_t jpp_; |
76 | }; |
77 | |
78 | explicit jit_uni_pooling_fwd_t(const pd_t *apd); |
79 | jit_uni_pooling_fwd_t(jit_uni_pooling_fwd_t &&) = default; |
80 | jit_uni_pooling_fwd_t &operator=(jit_uni_pooling_fwd_t &&) = default; |
81 | ~jit_uni_pooling_fwd_t(); |
82 | |
83 | using data_t = typename prec_traits<d_type>::type; |
84 | |
85 | status_t init(engine_t *engine) override; |
86 | |
87 | status_t execute(const exec_ctx_t &ctx) const override { |
88 | auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
89 | auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST); |
90 | auto ws = CTX_OUT_MEM(char *, DNNL_ARG_WORKSPACE); |
91 | |
92 | if (pd()->ndims() == 5) |
93 | execute_forward_3d(src, dst, ws, ctx); |
94 | else |
95 | execute_forward(src, dst, ws, ctx); |
96 | |
97 | return status::success; |
98 | } |
99 | |
100 | private: |
101 | void execute_forward(const data_t *src, data_t *dst, char *indices, |
102 | const exec_ctx_t &ctx) const; |
103 | void execute_forward_3d(const data_t *src, data_t *dst, char *indices, |
104 | const exec_ctx_t &ctx) const; |
105 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
106 | status_t init_ncsp_trans_ctx(); |
107 | |
108 | std::unique_ptr<jit_uni_pool_kernel<isa>> kernel_; |
109 | std::unique_ptr<jit_uni_pooling_utils::trans_context_t> trans_ctx_; |
110 | static constexpr data_type_t wsp_dt_ = data_type::f32; |
111 | }; |
112 | |
113 | template <cpu_isa_t isa, impl::data_type_t d_type> |
114 | struct jit_uni_pooling_bwd_t : public primitive_t { |
115 | struct pd_t : public cpu_pooling_bwd_pd_t { |
116 | using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; |
117 | |
118 | DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:" , jpp_.isa, "" ), |
119 | jit_uni_pooling_bwd_t); |
120 | |
121 | status_t init(engine_t *engine) { |
122 | using namespace utils; |
123 | |
124 | const bool ok = true && set_default_params() == status::success |
125 | && !is_fwd() && !has_zero_dim_memory() |
126 | && everyone_is(d_type, diff_src_md()->data_type, |
127 | diff_dst_md()->data_type) |
128 | && attr()->has_default_values() && !is_dilated(); |
129 | if (!ok) return status::unimplemented; |
130 | |
131 | if (desc()->alg_kind == alg_kind::pooling_max) { |
132 | init_default_ws(); |
133 | if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; |
134 | } |
135 | |
136 | auto scratchpad = scratchpad_registry().registrar(); |
137 | |
138 | CHECK(jit_uni_pool_kernel<isa>::init_conf( |
139 | jpp_, scratchpad, attr_, this)); |
140 | |
141 | return status::success; |
142 | } |
143 | |
144 | jit_pool_conf_t jpp_; |
145 | }; |
146 | |
147 | explicit jit_uni_pooling_bwd_t(const pd_t *apd); |
148 | jit_uni_pooling_bwd_t(jit_uni_pooling_bwd_t &&) = default; |
149 | jit_uni_pooling_bwd_t &operator=(jit_uni_pooling_bwd_t &&) = default; |
150 | ~jit_uni_pooling_bwd_t(); |
151 | |
152 | using data_t = typename prec_traits<d_type>::type; |
153 | |
154 | status_t init(engine_t *engine) override; |
155 | |
156 | status_t execute(const exec_ctx_t &ctx) const override { |
157 | auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); |
158 | auto ws = CTX_IN_MEM(const char *, DNNL_ARG_WORKSPACE); |
159 | auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC); |
160 | |
161 | if (pd()->ndims() == 5) |
162 | execute_backward_3d(diff_dst, ws, diff_src, ctx); |
163 | else |
164 | execute_backward(diff_dst, ws, diff_src, ctx); |
165 | |
166 | return status::success; |
167 | } |
168 | |
169 | private: |
170 | void execute_backward(const data_t *diff_dst, const char *indices, |
171 | data_t *diff_src, const exec_ctx_t &ctx) const; |
172 | void execute_backward_3d(const data_t *diff_dst, const char *indices, |
173 | data_t *diff_src, const exec_ctx_t &ctx) const; |
174 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
175 | status_t init_ncsp_trans_ctx(); |
176 | |
177 | std::unique_ptr<jit_uni_pool_kernel<isa>> kernel_; |
178 | std::unique_ptr<jit_uni_pooling_utils::trans_context_t> trans_ctx_; |
179 | static constexpr data_type_t wsp_dt_ = data_type::f32; |
180 | }; |
181 | |
182 | } // namespace x64 |
183 | } // namespace cpu |
184 | } // namespace impl |
185 | } // namespace dnnl |
186 | |
187 | #endif |
188 | |
189 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
190 | |