1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_NCSP_BATCH_NORMALIZATION_HPP |
18 | #define CPU_NCSP_BATCH_NORMALIZATION_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/dnnl_thread.hpp" |
24 | #include "common/memory_tracking.hpp" |
25 | #include "common/primitive.hpp" |
26 | #include "common/type_helpers.hpp" |
27 | #include "common/utils.hpp" |
28 | |
29 | #include "cpu/platform.hpp" |
30 | |
31 | #include "cpu/cpu_batch_normalization_pd.hpp" |
32 | |
33 | namespace dnnl { |
34 | namespace impl { |
35 | namespace cpu { |
36 | |
37 | template <data_type_t d_type> |
38 | struct ncsp_batch_normalization_fwd_t : public primitive_t { |
39 | struct pd_t : public cpu_batch_normalization_fwd_pd_t { |
40 | using cpu_batch_normalization_fwd_pd_t:: |
41 | cpu_batch_normalization_fwd_pd_t; |
42 | |
43 | DECLARE_COMMON_PD_T("ncsp_bnorm:any" , ncsp_batch_normalization_fwd_t); |
44 | |
45 | status_t init(engine_t *engine) { |
46 | using namespace data_type; |
47 | using namespace format_tag; |
48 | |
49 | bool ok = is_fwd() && !has_zero_dim_memory() |
50 | && utils::everyone_is( |
51 | d_type, src_md()->data_type, dst_md()->data_type) |
52 | && platform::has_data_type_support(d_type) |
53 | && IMPLICATION(is_training(), |
54 | platform::has_training_support(d_type)) |
55 | && check_scale_shift_data_type() |
56 | && (attr()->has_default_values() |
57 | || with_relu_post_op(is_training())) |
58 | && set_default_formats_common() |
59 | && memory_desc_wrapper(src_md()) |
60 | == memory_desc_wrapper(dst_md()) |
61 | && memory_desc_matches_one_of_tag( |
62 | *src_md(), ncdhw, nchw, ncw); |
63 | if (!ok) return status::unimplemented; |
64 | |
65 | // BN+Add+Relu fusion is not currently implemented |
66 | if (fuse_norm_add_relu()) return status::unimplemented; |
67 | |
68 | if (is_training() && fuse_norm_relu()) init_default_ws(8); |
69 | |
70 | nthr_ = dnnl_get_max_threads(); |
71 | init_scratchpad(); |
72 | |
73 | return status::success; |
74 | } |
75 | |
76 | int nthr_; // To not exceed the limit in execute used for set up. |
77 | |
78 | private: |
79 | void init_scratchpad() { |
80 | using namespace memory_tracking::names; |
81 | auto scratchpad = scratchpad_registry().registrar(); |
82 | if (!stats_is_src()) { |
83 | scratchpad.template book<acc_data_t>( |
84 | key_bnorm_reduction, C() * nthr_); |
85 | |
86 | if (!is_training()) { |
87 | scratchpad.template book<acc_data_t>( |
88 | key_bnorm_tmp_mean, C()); |
89 | scratchpad.template book<acc_data_t>( |
90 | key_bnorm_tmp_var, C()); |
91 | } |
92 | } |
93 | |
94 | if (utils::one_of(d_type, data_type::bf16, data_type::f16)) { |
95 | const int simd_w = 16; |
96 | const int SP = D() * H() * W(); |
97 | const int nbufs = 2; |
98 | const size_t cvt_buf_sz |
99 | = nbufs * nthr_ * utils::rnd_up(SP, simd_w); |
100 | scratchpad.template book<acc_data_t>(key_bnorm_cvt, cvt_buf_sz); |
101 | } |
102 | } |
103 | }; |
104 | |
105 | typedef typename prec_traits<d_type>::type data_t; |
106 | typedef float acc_data_t; |
107 | |
108 | ncsp_batch_normalization_fwd_t(const pd_t *apd) : primitive_t(apd) {} |
109 | ~ncsp_batch_normalization_fwd_t() {} |
110 | |
111 | status_t execute(const exec_ctx_t &ctx) const override { |
112 | return execute_forward(ctx); |
113 | } |
114 | |
115 | private: |
116 | status_t execute_forward(const exec_ctx_t &ctx) const; |
117 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
118 | }; |
119 | |
120 | template <data_type_t d_type> |
121 | struct ncsp_batch_normalization_bwd_t : public primitive_t { |
122 | struct pd_t : public cpu_batch_normalization_bwd_pd_t { |
123 | using cpu_batch_normalization_bwd_pd_t:: |
124 | cpu_batch_normalization_bwd_pd_t; |
125 | |
126 | DECLARE_COMMON_PD_T("ncsp_bnorm:any" , ncsp_batch_normalization_bwd_t); |
127 | |
128 | status_t init(engine_t *engine) { |
129 | using namespace data_type; |
130 | using namespace format_tag; |
131 | |
132 | bool ok = !is_fwd() && !has_zero_dim_memory() |
133 | && utils::everyone_is(d_type, src_md()->data_type, |
134 | diff_dst_md()->data_type, diff_src_md()->data_type) |
135 | && platform::has_data_type_support(d_type) |
136 | && platform::has_training_support(d_type) |
137 | && check_scale_shift_data_type() |
138 | && attr()->has_default_values() |
139 | && set_default_formats_common() |
140 | && memory_desc_wrapper(diff_src_md()) |
141 | == memory_desc_wrapper(diff_dst_md()) |
142 | && memory_desc_matches_one_of_tag( |
143 | *src_md(), ncdhw, nchw, ncw) |
144 | && memory_desc_matches_one_of_tag( |
145 | *diff_src_md(), ncdhw, nchw, ncw); |
146 | if (!ok) return status::unimplemented; |
147 | |
148 | // BN+Add+Relu fusion is not currently implemented |
149 | if (fuse_norm_add_relu()) return status::unimplemented; |
150 | |
151 | if (fuse_norm_relu()) { |
152 | init_default_ws(8); |
153 | if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; |
154 | } |
155 | |
156 | nthr_ = dnnl_get_max_threads(); |
157 | init_scratchpad(); |
158 | |
159 | return status::success; |
160 | } |
161 | |
162 | int nthr_; // To not exceed the limit in execute used for set up. |
163 | |
164 | private: |
165 | void init_scratchpad() { |
166 | using namespace memory_tracking::names; |
167 | auto scratchpad = scratchpad_registry().registrar(); |
168 | scratchpad.template book<acc_data_t>( |
169 | key_bnorm_reduction, 2 * C() * nthr_); |
170 | const auto pk_is_bwd = desc()->prop_kind == prop_kind::backward; |
171 | size_t ss_size = 0; |
172 | if (!use_scale() || !pk_is_bwd) ss_size += C(); |
173 | if (!use_shift() || !pk_is_bwd) ss_size += C(); |
174 | |
175 | if (ss_size) |
176 | scratchpad.template book<acc_data_t>( |
177 | key_bnorm_tmp_diff_ss, ss_size); |
178 | |
179 | if (utils::one_of(d_type, data_type::bf16, data_type::f16)) { |
180 | const int simd_w = 16; |
181 | const int SP = D() * H() * W(); |
182 | const int nbufs = 2 + !use_global_stats(); |
183 | const size_t cvt_buf_sz |
184 | = nbufs * nthr_ * utils::rnd_up(SP, simd_w); |
185 | scratchpad.template book<acc_data_t>(key_bnorm_cvt, cvt_buf_sz); |
186 | } |
187 | } |
188 | }; |
189 | |
190 | typedef typename prec_traits<d_type>::type data_t; |
191 | typedef float acc_data_t; |
192 | |
193 | ncsp_batch_normalization_bwd_t(const pd_t *apd) : primitive_t(apd) {} |
194 | ~ncsp_batch_normalization_bwd_t() {} |
195 | |
196 | status_t execute(const exec_ctx_t &ctx) const override { |
197 | return execute_backward(ctx); |
198 | } |
199 | |
200 | private: |
201 | status_t execute_backward(const exec_ctx_t &ctx) const; |
202 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
203 | }; |
204 | |
205 | } // namespace cpu |
206 | } // namespace impl |
207 | } // namespace dnnl |
208 | |
209 | #endif |
210 | |
211 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
212 | |