1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_POOLING_HPP
18#define CPU_X64_JIT_UNI_POOLING_HPP
19
20#include <assert.h>
21#include <memory>
22
23#include "common/c_types_map.hpp"
24#include "common/dnnl_thread.hpp"
25#include "common/primitive.hpp"
26#include "common/type_helpers.hpp"
27#include "common/utils.hpp"
28
29#include "cpu/cpu_pooling_pd.hpp"
30#include "cpu/x64/jit_uni_pool_kernel.hpp"
31#include "cpu/x64/jit_uni_reorder.hpp"
32
33namespace dnnl {
34namespace impl {
35namespace cpu {
36namespace x64 {
37
// Helper types for the pooling primitives below. They are only
// forward-declared in this header; trans_context_t is held through the
// std::unique_ptr trans_ctx_ members of both primitives, and
// trans_wrapper_t is not referenced in this header.
namespace jit_uni_pooling_utils {
struct trans_context_t;
struct trans_wrapper_t;
} // namespace jit_uni_pooling_utils
42
43template <cpu_isa_t isa, impl::data_type_t d_type>
44struct jit_uni_pooling_fwd_t : public primitive_t {
45 struct pd_t : public cpu_pooling_fwd_pd_t {
46 using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
47
48 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", jpp_.isa, ""),
49 jit_uni_pooling_fwd_t);
50
51 status_t init(engine_t *engine) {
52 using namespace utils;
53
54 const bool ok = is_fwd() && !has_zero_dim_memory()
55 && everyone_is(
56 d_type, src_md()->data_type, dst_md()->data_type)
57 && attr()->has_default_values(
58 primitive_attr_t::skip_mask_t::post_ops, d_type)
59 && !is_dilated() && set_default_params() == status::success;
60 if (!ok) return status::unimplemented;
61
62 const bool is_training
63 = desc_.prop_kind == prop_kind::forward_training;
64 if (desc()->alg_kind == alg_kind::pooling_max && is_training)
65 init_default_ws();
66
67 auto scratchpad = scratchpad_registry().registrar();
68
69 CHECK(jit_uni_pool_kernel<isa>::init_conf(
70 jpp_, scratchpad, attr_, this));
71
72 return status::success;
73 }
74
75 jit_pool_conf_t jpp_;
76 };
77
78 explicit jit_uni_pooling_fwd_t(const pd_t *apd);
79 jit_uni_pooling_fwd_t(jit_uni_pooling_fwd_t &&) = default;
80 jit_uni_pooling_fwd_t &operator=(jit_uni_pooling_fwd_t &&) = default;
81 ~jit_uni_pooling_fwd_t();
82
83 using data_t = typename prec_traits<d_type>::type;
84
85 status_t init(engine_t *engine) override;
86
87 status_t execute(const exec_ctx_t &ctx) const override {
88 auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
89 auto dst = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
90 auto ws = CTX_OUT_MEM(char *, DNNL_ARG_WORKSPACE);
91
92 if (pd()->ndims() == 5)
93 execute_forward_3d(src, dst, ws, ctx);
94 else
95 execute_forward(src, dst, ws, ctx);
96
97 return status::success;
98 }
99
100private:
101 void execute_forward(const data_t *src, data_t *dst, char *indices,
102 const exec_ctx_t &ctx) const;
103 void execute_forward_3d(const data_t *src, data_t *dst, char *indices,
104 const exec_ctx_t &ctx) const;
105 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
106 status_t init_ncsp_trans_ctx();
107
108 std::unique_ptr<jit_uni_pool_kernel<isa>> kernel_;
109 std::unique_ptr<jit_uni_pooling_utils::trans_context_t> trans_ctx_;
110 static constexpr data_type_t wsp_dt_ = data_type::f32;
111};
112
113template <cpu_isa_t isa, impl::data_type_t d_type>
114struct jit_uni_pooling_bwd_t : public primitive_t {
115 struct pd_t : public cpu_pooling_bwd_pd_t {
116 using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t;
117
118 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", jpp_.isa, ""),
119 jit_uni_pooling_bwd_t);
120
121 status_t init(engine_t *engine) {
122 using namespace utils;
123
124 const bool ok = true && set_default_params() == status::success
125 && !is_fwd() && !has_zero_dim_memory()
126 && everyone_is(d_type, diff_src_md()->data_type,
127 diff_dst_md()->data_type)
128 && attr()->has_default_values() && !is_dilated();
129 if (!ok) return status::unimplemented;
130
131 if (desc()->alg_kind == alg_kind::pooling_max) {
132 init_default_ws();
133 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
134 }
135
136 auto scratchpad = scratchpad_registry().registrar();
137
138 CHECK(jit_uni_pool_kernel<isa>::init_conf(
139 jpp_, scratchpad, attr_, this));
140
141 return status::success;
142 }
143
144 jit_pool_conf_t jpp_;
145 };
146
147 explicit jit_uni_pooling_bwd_t(const pd_t *apd);
148 jit_uni_pooling_bwd_t(jit_uni_pooling_bwd_t &&) = default;
149 jit_uni_pooling_bwd_t &operator=(jit_uni_pooling_bwd_t &&) = default;
150 ~jit_uni_pooling_bwd_t();
151
152 using data_t = typename prec_traits<d_type>::type;
153
154 status_t init(engine_t *engine) override;
155
156 status_t execute(const exec_ctx_t &ctx) const override {
157 auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
158 auto ws = CTX_IN_MEM(const char *, DNNL_ARG_WORKSPACE);
159 auto diff_src = CTX_OUT_MEM(data_t *, DNNL_ARG_DIFF_SRC);
160
161 if (pd()->ndims() == 5)
162 execute_backward_3d(diff_dst, ws, diff_src, ctx);
163 else
164 execute_backward(diff_dst, ws, diff_src, ctx);
165
166 return status::success;
167 }
168
169private:
170 void execute_backward(const data_t *diff_dst, const char *indices,
171 data_t *diff_src, const exec_ctx_t &ctx) const;
172 void execute_backward_3d(const data_t *diff_dst, const char *indices,
173 data_t *diff_src, const exec_ctx_t &ctx) const;
174 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
175 status_t init_ncsp_trans_ctx();
176
177 std::unique_ptr<jit_uni_pool_kernel<isa>> kernel_;
178 std::unique_ptr<jit_uni_pooling_utils::trans_context_t> trans_ctx_;
179 static constexpr data_type_t wsp_dt_ = data_type::f32;
180};
181
182} // namespace x64
183} // namespace cpu
184} // namespace impl
185} // namespace dnnl
186
187#endif
188
189// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
190