1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_NCHW_POOLING_HPP
18#define CPU_NCHW_POOLING_HPP
19
20#include <assert.h>
21
22#include "common/bfloat16.hpp"
23#include "common/c_types_map.hpp"
24#include "common/dnnl_thread.hpp"
25#include "common/primitive.hpp"
26#include "common/type_helpers.hpp"
27#include "common/utils.hpp"
28
29#include "cpu/cpu_pooling_pd.hpp"
30#include "cpu/platform.hpp"
31#include "cpu/primitive_attr_postops.hpp"
32
33namespace dnnl {
34namespace impl {
35namespace cpu {
36
37template <data_type_t d_type>
38struct nchw_pooling_fwd_t : public primitive_t {
39 struct pd_t : public cpu_pooling_fwd_pd_t {
40 using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
41
42 DECLARE_COMMON_PD_T("simple_nchw:any", nchw_pooling_fwd_t);
43
44 status_t init(engine_t *engine) {
45 const format_tag_t desired_fmt_tag = utils::pick(ndims() - 3,
46 format_tag::ncw, format_tag::nchw, format_tag::ncdhw);
47
48 const bool ok = is_fwd()
49 && utils::one_of(desc()->alg_kind, alg_kind::pooling_max,
50 alg_kind::pooling_avg_include_padding,
51 alg_kind::pooling_avg_exclude_padding)
52 && utils::everyone_is(
53 d_type, src_md()->data_type, dst_md()->data_type)
54 && platform::has_data_type_support(d_type)
55 && !has_zero_dim_memory() && !is_dilated()
56 && attr()->has_default_values(
57 primitive_attr_t::skip_mask_t::post_ops, d_type)
58 && set_default_params() == status::success
59 && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
60 && memory_desc_matches_tag(*dst_md(), desired_fmt_tag)
61 && attr_.set_default_formats(dst_md(0)) == status::success;
62 if (!ok) return status::unimplemented;
63
64 const bool is_training
65 = desc_.prop_kind == prop_kind::forward_training;
66 if (desc()->alg_kind == alg_kind::pooling_max && is_training)
67 init_default_ws();
68
69 init_scratchpad();
70
71 return status::success;
72 }
73
74 private:
75 void init_scratchpad() {
76 using namespace memory_tracking::names;
77 if (src_md()->data_type != data_type::f32) {
78 const size_t src_sz_ = ID() * IH() * IW() * IC() * MB();
79 auto scratchpad = scratchpad_registry().registrar();
80 scratchpad.template book<float>(key_pool_src_bf16cvt, src_sz_);
81 }
82 }
83 };
84
85 nchw_pooling_fwd_t(const pd_t *apd);
86
87 using data_t = typename prec_traits<d_type>::type;
88
89 status_t execute(const exec_ctx_t &ctx) const override {
90 return execute_forward(ctx);
91 }
92
93private:
94 status_t execute_forward(const exec_ctx_t &ctx) const;
95 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
96 const ref_post_ops_t ref_post_ops_;
97};
98
99template <data_type_t d_type>
100struct nchw_pooling_bwd_t : public primitive_t {
101 struct pd_t : public cpu_pooling_bwd_pd_t {
102 using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t;
103
104 DECLARE_COMMON_PD_T("simple_nchw:any", nchw_pooling_bwd_t);
105
106 status_t init(engine_t *engine) {
107 const format_tag_t desired_fmt_tag = utils::pick(ndims() - 3,
108 format_tag::ncw, format_tag::nchw, format_tag::ncdhw);
109
110 using namespace prop_kind;
111 using namespace alg_kind;
112 bool ok = !is_fwd()
113 && utils::one_of(desc()->alg_kind, alg_kind::pooling_max,
114 alg_kind::pooling_avg_include_padding,
115 alg_kind::pooling_avg_exclude_padding)
116 && utils::everyone_is(d_type, diff_dst_md()->data_type,
117 diff_src_md()->data_type)
118 && platform::has_data_type_support(d_type)
119 && !has_zero_dim_memory()
120 && set_default_params() == status::success
121 && attr()->has_default_values()
122 && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag)
123 && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag)
124 && !is_dilated();
125 if (!ok) return status::unimplemented;
126
127 if (desc()->alg_kind == pooling_max) {
128 bool ws_ok
129 = true && hint_fwd_pd_ && hint_fwd_pd_->workspace_md();
130 if (!ws_ok) return status::unimplemented;
131
132 const auto &ws_blk
133 = hint_fwd_pd_->workspace_md()->format_desc.blocking;
134 ws_ok = ws_ok && ws_blk.inner_nblks <= 1
135 && IMPLICATION(ws_blk.inner_nblks == 1,
136 ws_blk.inner_idxs[0] == 1);
137 if (!ws_ok) return status::unimplemented;
138
139 ws_md_ = *hint_fwd_pd_->workspace_md();
140 }
141
142 nthr_ = dnnl_get_max_threads();
143 calculate_channel_block_size();
144 init_scratchpad();
145
146 return status::success;
147 }
148
149 dim_t channel_block_size_;
150 int nthr_; // To not exceed the limit in execute used for set up.
151
152 private:
153 void init_scratchpad() {
154 using namespace memory_tracking::names;
155 if (diff_dst_md()->data_type != data_type::f32) {
156 size_t dst_sz_ = OD() * OH() * OW();
157 size_t src_sz_ = ID() * IH() * IW();
158 auto scratchpad = scratchpad_registry().registrar();
159
160 scratchpad.template book<float>(key_pool_src_bf16cvt,
161 src_sz_ * nthr_ * channel_block_size_);
162 scratchpad.template book<float>(key_pool_dst_bf16cvt,
163 dst_sz_ * nthr_ * channel_block_size_);
164 }
165 }
166
167 void calculate_channel_block_size() {
168 // calculate channels block size at which the data fits into half
169 // of L1, it allows to improve performance for problems with small
170 // spatial
171 dim_t dst_sz_ = OD() * OH() * OW();
172 dim_t src_sz_ = ID() * IH() * IW();
173 dim_t C_per_thr = nstl::min(MB() * IC() / nthr_, IC());
174 const dim_t max_block_size
175 = platform::get_per_core_cache_size(1) / 2;
176 dim_t data_size_per_ch = (dst_sz_ + src_sz_) * 6; // f32 + bf16
177 channel_block_size_ = nstl::max(
178 nstl::min(C_per_thr, max_block_size / data_size_per_ch),
179 (dim_t)1);
180 }
181 };
182
183 nchw_pooling_bwd_t(const pd_t *apd) : primitive_t(apd) {}
184 typedef typename prec_traits<d_type>::type data_t;
185
186 status_t execute(const exec_ctx_t &ctx) const override {
187 return execute_backward(ctx);
188 }
189
190private:
191 status_t execute_backward(const exec_ctx_t &ctx) const;
192 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
193};
194
195} // namespace cpu
196} // namespace impl
197} // namespace dnnl
198
199#endif
200
201// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
202