1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_NCSP_BATCH_NORMALIZATION_HPP
18#define CPU_NCSP_BATCH_NORMALIZATION_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/dnnl_thread.hpp"
24#include "common/memory_tracking.hpp"
25#include "common/primitive.hpp"
26#include "common/type_helpers.hpp"
27#include "common/utils.hpp"
28
29#include "cpu/platform.hpp"
30
31#include "cpu/cpu_batch_normalization_pd.hpp"
32
33namespace dnnl {
34namespace impl {
35namespace cpu {
36
37template <data_type_t d_type>
38struct ncsp_batch_normalization_fwd_t : public primitive_t {
39 struct pd_t : public cpu_batch_normalization_fwd_pd_t {
40 using cpu_batch_normalization_fwd_pd_t::
41 cpu_batch_normalization_fwd_pd_t;
42
43 DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_fwd_t);
44
45 status_t init(engine_t *engine) {
46 using namespace data_type;
47 using namespace format_tag;
48
49 bool ok = is_fwd() && !has_zero_dim_memory()
50 && utils::everyone_is(
51 d_type, src_md()->data_type, dst_md()->data_type)
52 && platform::has_data_type_support(d_type)
53 && IMPLICATION(is_training(),
54 platform::has_training_support(d_type))
55 && check_scale_shift_data_type()
56 && (attr()->has_default_values()
57 || with_relu_post_op(is_training()))
58 && set_default_formats_common()
59 && memory_desc_wrapper(src_md())
60 == memory_desc_wrapper(dst_md())
61 && memory_desc_matches_one_of_tag(
62 *src_md(), ncdhw, nchw, ncw);
63 if (!ok) return status::unimplemented;
64
65 // BN+Add+Relu fusion is not currently implemented
66 if (fuse_norm_add_relu()) return status::unimplemented;
67
68 if (is_training() && fuse_norm_relu()) init_default_ws(8);
69
70 nthr_ = dnnl_get_max_threads();
71 init_scratchpad();
72
73 return status::success;
74 }
75
76 int nthr_; // To not exceed the limit in execute used for set up.
77
78 private:
79 void init_scratchpad() {
80 using namespace memory_tracking::names;
81 auto scratchpad = scratchpad_registry().registrar();
82 if (!stats_is_src()) {
83 scratchpad.template book<acc_data_t>(
84 key_bnorm_reduction, C() * nthr_);
85
86 if (!is_training()) {
87 scratchpad.template book<acc_data_t>(
88 key_bnorm_tmp_mean, C());
89 scratchpad.template book<acc_data_t>(
90 key_bnorm_tmp_var, C());
91 }
92 }
93
94 if (utils::one_of(d_type, data_type::bf16, data_type::f16)) {
95 const int simd_w = 16;
96 const int SP = D() * H() * W();
97 const int nbufs = 2;
98 const size_t cvt_buf_sz
99 = nbufs * nthr_ * utils::rnd_up(SP, simd_w);
100 scratchpad.template book<acc_data_t>(key_bnorm_cvt, cvt_buf_sz);
101 }
102 }
103 };
104
105 typedef typename prec_traits<d_type>::type data_t;
106 typedef float acc_data_t;
107
108 ncsp_batch_normalization_fwd_t(const pd_t *apd) : primitive_t(apd) {}
109 ~ncsp_batch_normalization_fwd_t() {}
110
111 status_t execute(const exec_ctx_t &ctx) const override {
112 return execute_forward(ctx);
113 }
114
115private:
116 status_t execute_forward(const exec_ctx_t &ctx) const;
117 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
118};
119
120template <data_type_t d_type>
121struct ncsp_batch_normalization_bwd_t : public primitive_t {
122 struct pd_t : public cpu_batch_normalization_bwd_pd_t {
123 using cpu_batch_normalization_bwd_pd_t::
124 cpu_batch_normalization_bwd_pd_t;
125
126 DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_bwd_t);
127
128 status_t init(engine_t *engine) {
129 using namespace data_type;
130 using namespace format_tag;
131
132 bool ok = !is_fwd() && !has_zero_dim_memory()
133 && utils::everyone_is(d_type, src_md()->data_type,
134 diff_dst_md()->data_type, diff_src_md()->data_type)
135 && platform::has_data_type_support(d_type)
136 && platform::has_training_support(d_type)
137 && check_scale_shift_data_type()
138 && attr()->has_default_values()
139 && set_default_formats_common()
140 && memory_desc_wrapper(diff_src_md())
141 == memory_desc_wrapper(diff_dst_md())
142 && memory_desc_matches_one_of_tag(
143 *src_md(), ncdhw, nchw, ncw)
144 && memory_desc_matches_one_of_tag(
145 *diff_src_md(), ncdhw, nchw, ncw);
146 if (!ok) return status::unimplemented;
147
148 // BN+Add+Relu fusion is not currently implemented
149 if (fuse_norm_add_relu()) return status::unimplemented;
150
151 if (fuse_norm_relu()) {
152 init_default_ws(8);
153 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
154 }
155
156 nthr_ = dnnl_get_max_threads();
157 init_scratchpad();
158
159 return status::success;
160 }
161
162 int nthr_; // To not exceed the limit in execute used for set up.
163
164 private:
165 void init_scratchpad() {
166 using namespace memory_tracking::names;
167 auto scratchpad = scratchpad_registry().registrar();
168 scratchpad.template book<acc_data_t>(
169 key_bnorm_reduction, 2 * C() * nthr_);
170 const auto pk_is_bwd = desc()->prop_kind == prop_kind::backward;
171 size_t ss_size = 0;
172 if (!use_scale() || !pk_is_bwd) ss_size += C();
173 if (!use_shift() || !pk_is_bwd) ss_size += C();
174
175 if (ss_size)
176 scratchpad.template book<acc_data_t>(
177 key_bnorm_tmp_diff_ss, ss_size);
178
179 if (utils::one_of(d_type, data_type::bf16, data_type::f16)) {
180 const int simd_w = 16;
181 const int SP = D() * H() * W();
182 const int nbufs = 2 + !use_global_stats();
183 const size_t cvt_buf_sz
184 = nbufs * nthr_ * utils::rnd_up(SP, simd_w);
185 scratchpad.template book<acc_data_t>(key_bnorm_cvt, cvt_buf_sz);
186 }
187 }
188 };
189
190 typedef typename prec_traits<d_type>::type data_t;
191 typedef float acc_data_t;
192
193 ncsp_batch_normalization_bwd_t(const pd_t *apd) : primitive_t(apd) {}
194 ~ncsp_batch_normalization_bwd_t() {}
195
196 status_t execute(const exec_ctx_t &ctx) const override {
197 return execute_backward(ctx);
198 }
199
200private:
201 status_t execute_backward(const exec_ctx_t &ctx) const;
202 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
203};
204
205} // namespace cpu
206} // namespace impl
207} // namespace dnnl
208
209#endif
210
211// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
212