1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_NSPC_BATCH_NORMALIZATION_HPP
18#define CPU_NSPC_BATCH_NORMALIZATION_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/dnnl_thread.hpp"
24#include "common/memory_tracking.hpp"
25#include "common/primitive.hpp"
26#include "common/type_helpers.hpp"
27#include "common/utils.hpp"
28
29#include "cpu/cpu_batch_normalization_pd.hpp"
30#include "cpu/platform.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35
36template <data_type_t d_type>
37struct nspc_batch_normalization_fwd_t : public primitive_t {
38 struct pd_t : public cpu_batch_normalization_fwd_pd_t {
39 pd_t(const batch_normalization_desc_t *adesc,
40 const primitive_attr_t *attr,
41 const batch_normalization_fwd_pd_t *hint_fwd_pd)
42 : cpu_batch_normalization_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
43
44 DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_fwd_t);
45
46 status_t init(engine_t *engine) {
47 using namespace data_type;
48 using namespace format_tag;
49
50 bool ok = is_fwd() && !has_zero_dim_memory()
51 && utils::everyone_is(
52 d_type, src_md()->data_type, dst_md()->data_type)
53 && platform::has_data_type_support(d_type)
54 && IMPLICATION(is_training(),
55 platform::has_training_support(d_type))
56 && check_scale_shift_data_type()
57 && (attr()->has_default_values()
58 || with_relu_post_op(is_training()))
59 && set_default_formats_common()
60 && memory_desc_wrapper(src_md())
61 == memory_desc_wrapper(dst_md())
62 && memory_desc_matches_one_of_tag(
63 *src_md(), ndhwc, nhwc, nwc, nc);
64 if (!ok) return status::unimplemented;
65
66 // BN+Add+Relu fusion is not currently implemented
67 if (fuse_norm_add_relu()) return status::unimplemented;
68
69 if (is_training() && fuse_norm_relu()) init_default_ws(8);
70
71 nthr_ = dnnl_get_max_threads();
72 init_scratchpad();
73
74 return status::success;
75 }
76
77 int nthr_; // To not exceed the limit in execute used for set up.
78
79 private:
80 void init_scratchpad() {
81 using namespace memory_tracking::names;
82 using namespace data_type;
83
84 auto scratchpad = scratchpad_registry().registrar();
85 if (!stats_is_src()) {
86 const size_t stats_buf_sz = nstl::max(C(), dim_t(16)) * nthr_;
87 scratchpad.template book<acc_data_t>(
88 key_bnorm_reduction, stats_buf_sz);
89 scratchpad.template book<acc_data_t>(
90 key_bnorm_tmp_mean, stats_buf_sz);
91 scratchpad.template book<acc_data_t>(
92 key_bnorm_tmp_var, stats_buf_sz);
93 }
94 if (utils::one_of(d_type, bf16, f16)) {
95 const int simd_w = 16;
96 const int nbufs = 2;
97 const size_t cvt_buf_sz
98 = nbufs * nthr_ * utils::rnd_up(C(), simd_w);
99 scratchpad.template book<acc_data_t>(key_bnorm_cvt, cvt_buf_sz);
100 }
101 }
102 };
103
104 typedef typename prec_traits<d_type>::type data_t;
105 typedef float acc_data_t;
106
107 nspc_batch_normalization_fwd_t(const pd_t *apd) : primitive_t(apd) {}
108 ~nspc_batch_normalization_fwd_t() {}
109
110 status_t execute(const exec_ctx_t &ctx) const override {
111 return execute_forward(ctx);
112 }
113
114private:
115 status_t execute_forward(const exec_ctx_t &ctx) const;
116 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
117};
118
119template <data_type_t d_type>
120struct nspc_batch_normalization_bwd_t : public primitive_t {
121 struct pd_t : public cpu_batch_normalization_bwd_pd_t {
122 pd_t(const batch_normalization_desc_t *adesc,
123 const primitive_attr_t *attr,
124 const batch_normalization_fwd_pd_t *hint_fwd_pd)
125 : cpu_batch_normalization_bwd_pd_t(adesc, attr, hint_fwd_pd) {}
126
127 DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_bwd_t);
128
129 status_t init(engine_t *engine) {
130 using namespace data_type;
131 using namespace format_tag;
132
133 bool ok = !is_fwd() && !has_zero_dim_memory()
134 && utils::everyone_is(d_type, src_md()->data_type,
135 diff_dst_md()->data_type, diff_src_md()->data_type)
136 && platform::has_data_type_support(d_type)
137 && platform::has_training_support(d_type)
138 && check_scale_shift_data_type()
139 && attr()->has_default_values()
140 && set_default_formats_common()
141 && memory_desc_wrapper(diff_src_md())
142 == memory_desc_wrapper(diff_dst_md())
143 && memory_desc_matches_one_of_tag(
144 *src_md(), ndhwc, nhwc, nwc, nc)
145 && memory_desc_matches_one_of_tag(
146 *diff_src_md(), ndhwc, nhwc, nwc, nc);
147 if (!ok) return status::unimplemented;
148
149 // BN+Add+Relu fusion is not currently implemented
150 if (fuse_norm_add_relu()) return status::unimplemented;
151
152 if (fuse_norm_relu()) {
153 init_default_ws(8);
154 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
155 }
156
157 nthr_ = dnnl_get_max_threads();
158 init_scratchpad();
159
160 return status::success;
161 }
162
163 int nthr_; // To not exceed the limit in execute used for set up.
164
165 private:
166 void init_scratchpad() {
167 using namespace memory_tracking::names;
168 using namespace data_type;
169
170 auto scratchpad = scratchpad_registry().registrar();
171 scratchpad.template book<acc_data_t>(
172 key_bnorm_reduction, 2 * C() * nthr_);
173 scratchpad.template book<acc_data_t>(
174 key_bnorm_tmp_diff_ss, 2 * C() * (nthr_ + 1));
175 if (utils::one_of(d_type, bf16, f16)) {
176 const int simd_w = 16;
177 const int nbufs = 2 + !use_global_stats();
178 const size_t cvt_buf_sz
179 = nbufs * nthr_ * utils::rnd_up(C(), simd_w);
180 scratchpad.template book<acc_data_t>(key_bnorm_cvt, cvt_buf_sz);
181 }
182 }
183 };
184
185 typedef typename prec_traits<d_type>::type data_t;
186 typedef float acc_data_t;
187
188 nspc_batch_normalization_bwd_t(const pd_t *apd) : primitive_t(apd) {}
189 ~nspc_batch_normalization_bwd_t() {}
190
191 status_t execute(const exec_ctx_t &ctx) const override {
192 return execute_backward(ctx);
193 }
194
195private:
196 status_t execute_backward(const exec_ctx_t &ctx) const;
197 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
198};
199
200} // namespace cpu
201} // namespace impl
202} // namespace dnnl
203
204#endif
205
206// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
207