1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_NCHW_POOLING_HPP |
18 | #define CPU_NCHW_POOLING_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/bfloat16.hpp" |
23 | #include "common/c_types_map.hpp" |
24 | #include "common/dnnl_thread.hpp" |
25 | #include "common/primitive.hpp" |
26 | #include "common/type_helpers.hpp" |
27 | #include "common/utils.hpp" |
28 | |
29 | #include "cpu/cpu_pooling_pd.hpp" |
30 | #include "cpu/platform.hpp" |
31 | #include "cpu/primitive_attr_postops.hpp" |
32 | |
33 | namespace dnnl { |
34 | namespace impl { |
35 | namespace cpu { |
36 | |
37 | template <data_type_t d_type> |
38 | struct nchw_pooling_fwd_t : public primitive_t { |
39 | struct pd_t : public cpu_pooling_fwd_pd_t { |
40 | using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t; |
41 | |
42 | DECLARE_COMMON_PD_T("simple_nchw:any" , nchw_pooling_fwd_t); |
43 | |
44 | status_t init(engine_t *engine) { |
45 | const format_tag_t desired_fmt_tag = utils::pick(ndims() - 3, |
46 | format_tag::ncw, format_tag::nchw, format_tag::ncdhw); |
47 | |
48 | const bool ok = is_fwd() |
49 | && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, |
50 | alg_kind::pooling_avg_include_padding, |
51 | alg_kind::pooling_avg_exclude_padding) |
52 | && utils::everyone_is( |
53 | d_type, src_md()->data_type, dst_md()->data_type) |
54 | && platform::has_data_type_support(d_type) |
55 | && !has_zero_dim_memory() && !is_dilated() |
56 | && attr()->has_default_values( |
57 | primitive_attr_t::skip_mask_t::post_ops, d_type) |
58 | && set_default_params() == status::success |
59 | && memory_desc_matches_tag(*src_md(), desired_fmt_tag) |
60 | && memory_desc_matches_tag(*dst_md(), desired_fmt_tag) |
61 | && attr_.set_default_formats(dst_md(0)) == status::success; |
62 | if (!ok) return status::unimplemented; |
63 | |
64 | const bool is_training |
65 | = desc_.prop_kind == prop_kind::forward_training; |
66 | if (desc()->alg_kind == alg_kind::pooling_max && is_training) |
67 | init_default_ws(); |
68 | |
69 | init_scratchpad(); |
70 | |
71 | return status::success; |
72 | } |
73 | |
74 | private: |
75 | void init_scratchpad() { |
76 | using namespace memory_tracking::names; |
77 | if (src_md()->data_type != data_type::f32) { |
78 | const size_t src_sz_ = ID() * IH() * IW() * IC() * MB(); |
79 | auto scratchpad = scratchpad_registry().registrar(); |
80 | scratchpad.template book<float>(key_pool_src_bf16cvt, src_sz_); |
81 | } |
82 | } |
83 | }; |
84 | |
85 | nchw_pooling_fwd_t(const pd_t *apd); |
86 | |
87 | using data_t = typename prec_traits<d_type>::type; |
88 | |
89 | status_t execute(const exec_ctx_t &ctx) const override { |
90 | return execute_forward(ctx); |
91 | } |
92 | |
93 | private: |
94 | status_t execute_forward(const exec_ctx_t &ctx) const; |
95 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
96 | const ref_post_ops_t ref_post_ops_; |
97 | }; |
98 | |
99 | template <data_type_t d_type> |
100 | struct nchw_pooling_bwd_t : public primitive_t { |
101 | struct pd_t : public cpu_pooling_bwd_pd_t { |
102 | using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t; |
103 | |
104 | DECLARE_COMMON_PD_T("simple_nchw:any" , nchw_pooling_bwd_t); |
105 | |
106 | status_t init(engine_t *engine) { |
107 | const format_tag_t desired_fmt_tag = utils::pick(ndims() - 3, |
108 | format_tag::ncw, format_tag::nchw, format_tag::ncdhw); |
109 | |
110 | using namespace prop_kind; |
111 | using namespace alg_kind; |
112 | bool ok = !is_fwd() |
113 | && utils::one_of(desc()->alg_kind, alg_kind::pooling_max, |
114 | alg_kind::pooling_avg_include_padding, |
115 | alg_kind::pooling_avg_exclude_padding) |
116 | && utils::everyone_is(d_type, diff_dst_md()->data_type, |
117 | diff_src_md()->data_type) |
118 | && platform::has_data_type_support(d_type) |
119 | && !has_zero_dim_memory() |
120 | && set_default_params() == status::success |
121 | && attr()->has_default_values() |
122 | && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag) |
123 | && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag) |
124 | && !is_dilated(); |
125 | if (!ok) return status::unimplemented; |
126 | |
127 | if (desc()->alg_kind == pooling_max) { |
128 | bool ws_ok |
129 | = true && hint_fwd_pd_ && hint_fwd_pd_->workspace_md(); |
130 | if (!ws_ok) return status::unimplemented; |
131 | |
132 | const auto &ws_blk |
133 | = hint_fwd_pd_->workspace_md()->format_desc.blocking; |
134 | ws_ok = ws_ok && ws_blk.inner_nblks <= 1 |
135 | && IMPLICATION(ws_blk.inner_nblks == 1, |
136 | ws_blk.inner_idxs[0] == 1); |
137 | if (!ws_ok) return status::unimplemented; |
138 | |
139 | ws_md_ = *hint_fwd_pd_->workspace_md(); |
140 | } |
141 | |
142 | nthr_ = dnnl_get_max_threads(); |
143 | calculate_channel_block_size(); |
144 | init_scratchpad(); |
145 | |
146 | return status::success; |
147 | } |
148 | |
149 | dim_t channel_block_size_; |
150 | int nthr_; // To not exceed the limit in execute used for set up. |
151 | |
152 | private: |
153 | void init_scratchpad() { |
154 | using namespace memory_tracking::names; |
155 | if (diff_dst_md()->data_type != data_type::f32) { |
156 | size_t dst_sz_ = OD() * OH() * OW(); |
157 | size_t src_sz_ = ID() * IH() * IW(); |
158 | auto scratchpad = scratchpad_registry().registrar(); |
159 | |
160 | scratchpad.template book<float>(key_pool_src_bf16cvt, |
161 | src_sz_ * nthr_ * channel_block_size_); |
162 | scratchpad.template book<float>(key_pool_dst_bf16cvt, |
163 | dst_sz_ * nthr_ * channel_block_size_); |
164 | } |
165 | } |
166 | |
167 | void calculate_channel_block_size() { |
168 | // calculate channels block size at which the data fits into half |
169 | // of L1, it allows to improve performance for problems with small |
170 | // spatial |
171 | dim_t dst_sz_ = OD() * OH() * OW(); |
172 | dim_t src_sz_ = ID() * IH() * IW(); |
173 | dim_t C_per_thr = nstl::min(MB() * IC() / nthr_, IC()); |
174 | const dim_t max_block_size |
175 | = platform::get_per_core_cache_size(1) / 2; |
176 | dim_t data_size_per_ch = (dst_sz_ + src_sz_) * 6; // f32 + bf16 |
177 | channel_block_size_ = nstl::max( |
178 | nstl::min(C_per_thr, max_block_size / data_size_per_ch), |
179 | (dim_t)1); |
180 | } |
181 | }; |
182 | |
183 | nchw_pooling_bwd_t(const pd_t *apd) : primitive_t(apd) {} |
184 | typedef typename prec_traits<d_type>::type data_t; |
185 | |
186 | status_t execute(const exec_ctx_t &ctx) const override { |
187 | return execute_backward(ctx); |
188 | } |
189 | |
190 | private: |
191 | status_t execute_backward(const exec_ctx_t &ctx) const; |
192 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
193 | }; |
194 | |
195 | } // namespace cpu |
196 | } // namespace impl |
197 | } // namespace dnnl |
198 | |
199 | #endif |
200 | |
201 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
202 | |