1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_NHWC_POOLING_HPP
18#define CPU_NHWC_POOLING_HPP
19
20#include <assert.h>
21
22#include "common/bfloat16.hpp"
23#include "common/c_types_map.hpp"
24#include "common/dnnl_thread.hpp"
25#include "common/primitive.hpp"
26#include "common/type_helpers.hpp"
27#include "common/utils.hpp"
28
29#include "cpu/cpu_pooling_pd.hpp"
30#include "cpu/platform.hpp"
31#include "cpu/primitive_attr_postops.hpp"
32
33namespace dnnl {
34namespace impl {
35namespace cpu {
36
namespace nhwc_pooling {
// Flat element offset helper shared by the NHWC pooling kernels; defined out
// of line. Presumably returns n * sn + d * sd + h * sh + w * sw (index times
// stride per dimension) — NOTE(review): confirm against the definition.
size_t strided_offset(int n, size_t sn, int d, size_t sd, int h, size_t sh,
        int w, size_t sw);
} // namespace nhwc_pooling
42
43template <data_type_t d_type>
44struct nhwc_pooling_fwd_t : public primitive_t {
45 struct pd_t : public cpu_pooling_fwd_pd_t {
46 using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
47
48 DECLARE_COMMON_PD_T("simple_nhwc:any", nhwc_pooling_fwd_t);
49
50 status_t init(engine_t *engine) {
51 const format_tag_t desired_fmt_tag = utils::pick(ndims() - 3,
52 format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
53
54 using namespace prop_kind;
55 using namespace alg_kind;
56 const bool ok = is_fwd()
57 && utils::one_of(desc()->alg_kind, pooling_max,
58 pooling_avg_include_padding,
59 pooling_avg_exclude_padding)
60 && utils::everyone_is(
61 d_type, src_md()->data_type, dst_md()->data_type)
62 && platform::has_data_type_support(d_type) && !is_dilated()
63 && attr()->has_default_values(
64 primitive_attr_t::skip_mask_t::post_ops, d_type)
65 && set_default_params() == status::success
66 && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
67 && memory_desc_matches_tag(*dst_md(), desired_fmt_tag)
68 && attr_.set_default_formats(dst_md(0)) == status::success;
69 if (!ok) return status::unimplemented;
70
71 const bool is_training = desc_.prop_kind == forward_training;
72 if (desc()->alg_kind == pooling_max && is_training) {
73 init_default_ws();
74 }
75
76 nthr_ = dnnl_get_max_threads();
77 init_scratchpad();
78
79 return status::success;
80 }
81
82 int nthr_; // To not exceed the limit in execute used for set up.
83
84 private:
85 void init_scratchpad() {
86 using namespace memory_tracking::names;
87 if (src_md()->data_type != data_type::f32) {
88 const size_t bf16cvt_sz_ = IC() * nthr_;
89 auto scratchpad = scratchpad_registry().registrar();
90 scratchpad.template book<float>(
91 key_pool_src_bf16cvt, bf16cvt_sz_);
92 scratchpad.template book<float>(
93 key_pool_dst_bf16cvt, bf16cvt_sz_);
94 }
95 }
96 };
97
98 nhwc_pooling_fwd_t(const pd_t *apd);
99
100 using data_t = typename prec_traits<d_type>::type;
101 using ker_data_t = typename prec_traits<data_type::f32>::type;
102
103 status_t execute(const exec_ctx_t &ctx) const override {
104 return execute_forward(ctx);
105 }
106
107private:
108 status_t execute_forward(const exec_ctx_t &ctx) const;
109 void array_div_by_const(const int n, const ker_data_t *src,
110 const size_t num, ker_data_t *dst) const;
111 void array_add(const int n, const ker_data_t *src, ker_data_t *dst) const;
112 void array_nhwc_max(const int n, ker_data_t *dst, const ker_data_t *src,
113 unsigned char *ws, const size_t ws_offset, const data_type_t ws_dt,
114 const int index) const;
115 void array_nhwc_initialize(const int n, ker_data_t *dst, unsigned char *ws,
116 const size_t ws_offset, const data_type_t ws_dt) const;
117
118 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
119 const ref_post_ops_t ref_post_ops_;
120};
121
122template <impl::data_type_t d_type>
123struct nhwc_pooling_bwd_t : public primitive_t {
124 struct pd_t : public cpu_pooling_bwd_pd_t {
125 using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t;
126
127 DECLARE_COMMON_PD_T("simple_nhwc:any", nhwc_pooling_bwd_t);
128
129 status_t init(engine_t *engine) {
130 const format_tag_t desired_fmt_tag = utils::pick(ndims() - 3,
131 format_tag::nwc, format_tag::nhwc, format_tag::ndhwc);
132
133 using namespace prop_kind;
134 using namespace alg_kind;
135 bool ok = !is_fwd()
136 && utils::one_of(desc()->alg_kind, pooling_max,
137 pooling_avg_include_padding,
138 pooling_avg_exclude_padding)
139 && utils::everyone_is(d_type, diff_dst_md()->data_type,
140 diff_src_md()->data_type)
141 && platform::has_data_type_support(d_type)
142 && set_default_params() == status::success && !is_fwd()
143 && attr()->has_default_values()
144 && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag)
145 && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag)
146 && !is_dilated();
147 if (!ok) return status::unimplemented;
148
149 if (desc()->alg_kind == pooling_max) {
150 init_default_ws();
151 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
152 }
153
154 nthr_ = dnnl_get_max_threads();
155 init_scratchpad();
156
157 return status::success;
158 }
159
160 int nthr_; // To not exceed the limit in execute used for set up.
161
162 private:
163 void init_scratchpad() {
164 using namespace memory_tracking::names;
165 if (diff_src_md()->data_type != data_type::f32) {
166 size_t bf16cvt_sz_ = IC() * nthr_;
167 auto scratchpad = scratchpad_registry().registrar();
168 scratchpad.template book<float>(
169 key_pool_src_bf16cvt, bf16cvt_sz_);
170 scratchpad.template book<float>(
171 key_pool_dst_bf16cvt, bf16cvt_sz_);
172 }
173 }
174 };
175
176 nhwc_pooling_bwd_t(const pd_t *apd) : primitive_t(apd) {}
177 typedef typename prec_traits<d_type>::type data_t;
178
179 status_t execute(const exec_ctx_t &ctx) const override {
180 return execute_backward(ctx);
181 }
182
183private:
184 status_t execute_backward(const exec_ctx_t &ctx) const;
185 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
186};
187
188} // namespace cpu
189} // namespace impl
190} // namespace dnnl
191
192#endif
193
194// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
195