1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_OCL_GEN9_BATCH_NORMALIZATION_HPP |
18 | #define GPU_OCL_GEN9_BATCH_NORMALIZATION_HPP |
19 | |
20 | #include "common/experimental.hpp" |
21 | #include "common/primitive.hpp" |
22 | #include "gpu/compute/compute.hpp" |
23 | #include "gpu/gpu_batch_normalization_pd.hpp" |
24 | #include "gpu/gpu_primitive.hpp" |
25 | #include "gpu/gpu_resource.hpp" |
26 | #include "gpu/primitive_conf.hpp" |
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace gpu { |
31 | namespace ocl { |
32 | |
33 | struct gen9_batch_normalization_fwd_t : public gpu_primitive_t { |
34 | using gpu_primitive_t::gpu_primitive_t; |
35 | struct pd_t : public gpu_batch_normalization_fwd_pd_t { |
36 | pd_t(const batch_normalization_desc_t *adesc, |
37 | const primitive_attr_t *attr, |
38 | const batch_normalization_fwd_pd_t *hint_fwd_pd) |
39 | : gpu_batch_normalization_fwd_pd_t(adesc, attr, hint_fwd_pd) {} |
40 | |
41 | DECLARE_COMMON_PD_T(impl_name(), gen9_batch_normalization_fwd_t); |
42 | |
43 | const char *impl_name() const { |
44 | return conf.nhwc_optimized |
45 | ? (conf.use_stats_one_pass ? "ocl:gen9:nhwc:onepass" |
46 | : "ocl:gen9:nhwc" ) |
47 | : (conf.use_stats_one_pass ? "ocl:gen9:blocked:onepass" |
48 | : "ocl:gen9:blocked" ); |
49 | } |
50 | status_t init(engine_t *engine) { |
51 | using namespace data_type; |
52 | auto *compute_engine |
53 | = utils::downcast<compute::compute_engine_t *>(engine); |
54 | |
55 | const auto attr_skip_mask = primitive_attr_t::skip_mask_t::post_ops; |
56 | |
57 | bool ok = is_fwd() |
58 | && utils::one_of(src_md()->data_type, f32, bf16, f16, s8) |
59 | && src_md()->data_type == dst_md()->data_type |
60 | && IMPLICATION(src_md()->data_type == s8, |
61 | !is_training() && stats_is_src()) |
62 | && check_scale_shift_data_type() |
63 | && attr()->has_default_values(attr_skip_mask) |
64 | && IMPLICATION(!attr()->has_default_values(), |
65 | attr()->post_ops_.len() == 1 && with_relu_post_op()) |
66 | && set_default_formats_common() |
67 | && memory_desc_wrapper(src_md()) |
68 | == memory_desc_wrapper(dst_md()) |
69 | && compute_engine->mayiuse( |
70 | compute::device_ext_t::intel_subgroups); |
71 | if (!ok) return status::unimplemented; |
72 | |
73 | if (is_training() && (fuse_norm_relu() || fuse_norm_add_relu())) |
74 | init_default_ws(8); |
75 | |
76 | status_t status = init_conf(engine); |
77 | if (status != status::success) return status; |
78 | init_scratchpad(); |
79 | |
80 | return status::success; |
81 | } |
82 | |
83 | status_t init_conf(engine_t *engine); |
84 | status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const; |
85 | void init_scratchpad(); |
86 | |
87 | bnorm_conf_t conf; |
88 | offsets_t off; |
89 | }; |
90 | |
91 | status_t init(engine_t *engine) override { |
92 | compute::kernel_ctx_t kernel_ctx; |
93 | |
94 | status_t status = pd()->init_kernel_ctx(kernel_ctx); |
95 | CHECK(status); |
96 | |
97 | std::vector<const char *> kernel_names = {"gen9_bnorm_fwd" , nullptr, |
98 | nullptr, nullptr, nullptr, nullptr, nullptr}; |
99 | if (pd()->conf.calculate_stats) { |
100 | if (pd()->conf.use_stats_one_pass) { |
101 | kernel_names[1] = "gen9_calc_mean_var" ; |
102 | kernel_names[2] = "gen9_reduce_mean_var" ; |
103 | kernel_names[3] = "gen9_fused_reduce_init" ; |
104 | kernel_names[4] = "gen9_fused_reduce_final" ; |
105 | } else { |
106 | kernel_names[1] = "gen9_calc_mean" ; |
107 | kernel_names[2] = "gen9_calc_variance" ; |
108 | kernel_names[3] = "gen9_reduce_mean" ; |
109 | kernel_names[4] = "gen9_reduce_variance" ; |
110 | kernel_names[5] = "gen9_fused_reduce_init" ; |
111 | kernel_names[6] = "gen9_fused_reduce_final" ; |
112 | } |
113 | } |
114 | |
115 | std::vector<compute::kernel_t> kernels; |
116 | status = create_kernels(engine, &kernels, kernel_names, kernel_ctx); |
117 | CHECK(status); |
118 | |
119 | kernel_ = kernels[0]; |
120 | if (pd()->conf.use_stats_one_pass) { |
121 | calculate_mean_var_kernel_ = kernels[1]; |
122 | reduce_mean_var_kernel_ = kernels[2]; |
123 | reduce_init_kernel_ = kernels[3]; |
124 | reduce_final_kernel_ = kernels[4]; |
125 | } else { |
126 | calculate_mean_kernel_ = kernels[1]; |
127 | calculate_variance_kernel_ = kernels[2]; |
128 | reduce_mean_kernel_ = kernels[3]; |
129 | reduce_variance_kernel_ = kernels[4]; |
130 | reduce_init_kernel_ = kernels[5]; |
131 | reduce_final_kernel_ = kernels[6]; |
132 | } |
133 | |
134 | return status::success; |
135 | } |
136 | |
137 | status_t execute(const exec_ctx_t &ctx) const override { |
138 | return execute_forward(ctx); |
139 | } |
140 | |
141 | private: |
142 | status_t execute_forward(const exec_ctx_t &ctx) const; |
143 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
144 | compute::kernel_t kernel_; |
145 | compute::kernel_t calculate_mean_kernel_; |
146 | compute::kernel_t reduce_mean_kernel_; |
147 | compute::kernel_t calculate_variance_kernel_; |
148 | compute::kernel_t reduce_variance_kernel_; |
149 | compute::kernel_t calculate_mean_var_kernel_; |
150 | compute::kernel_t reduce_mean_var_kernel_; |
151 | compute::kernel_t reduce_init_kernel_; |
152 | compute::kernel_t reduce_final_kernel_; |
153 | }; |
154 | |
155 | struct gen9_batch_normalization_bwd_t : public gpu_primitive_t { |
156 | using gpu_primitive_t::gpu_primitive_t; |
157 | struct pd_t : public gpu_batch_normalization_bwd_pd_t { |
158 | pd_t(const batch_normalization_desc_t *adesc, |
159 | const primitive_attr_t *attr, |
160 | const batch_normalization_fwd_pd_t *hint_fwd_pd) |
161 | : gpu_batch_normalization_bwd_pd_t(adesc, attr, hint_fwd_pd) {} |
162 | |
163 | DECLARE_COMMON_PD_T(impl_name(), gen9_batch_normalization_bwd_t); |
164 | |
165 | const char *impl_name() const { |
166 | return conf.nhwc_optimized ? "ocl:gen9:nhwc" : "ocl:gen9:blocked" ; |
167 | } |
168 | |
169 | status_t init(engine_t *engine) { |
170 | auto *compute_engine |
171 | = utils::downcast<compute::compute_engine_t *>(engine); |
172 | using namespace data_type; |
173 | |
174 | bool ok = !is_fwd() && utils::one_of(src_md()->data_type, f32, bf16) |
175 | && src_md()->data_type == diff_src_md()->data_type |
176 | && diff_src_md()->data_type == diff_dst_md()->data_type |
177 | && check_scale_shift_data_type() |
178 | && attr()->has_default_values() |
179 | && set_default_formats_common() |
180 | && memory_desc_wrapper(diff_src_md()) |
181 | == memory_desc_wrapper(diff_dst_md()) |
182 | && compute_engine->mayiuse( |
183 | compute::device_ext_t::intel_subgroups); |
184 | if (!ok) return status::unimplemented; |
185 | |
186 | if (fuse_norm_relu() || fuse_norm_add_relu()) { |
187 | init_default_ws(8); |
188 | if (!compare_ws(hint_fwd_pd_)) return status::unimplemented; |
189 | } |
190 | |
191 | status_t status = init_conf(engine); |
192 | if (status != status::success) return status; |
193 | init_scratchpad(); |
194 | |
195 | return status::success; |
196 | } |
197 | |
198 | status_t init_conf(engine_t *engine); |
199 | status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const; |
200 | void init_scratchpad(); |
201 | |
202 | bnorm_conf_t conf; |
203 | offsets_t off; |
204 | }; |
205 | |
206 | status_t init(engine_t *engine) override { |
207 | compute::kernel_ctx_t kernel_ctx; |
208 | |
209 | status_t status = pd()->init_kernel_ctx(kernel_ctx); |
210 | CHECK(status); |
211 | |
212 | std::vector<const char *> kernel_names = {"gen9_bnorm_bwd" , |
213 | "gen9_calculate_stats" , "gen9_reduce_stats" , |
214 | "gen9_fused_reduce_init" , "gen9_fused_reduce_final" }; |
215 | |
216 | std::vector<compute::kernel_t> kernels; |
217 | status = create_kernels(engine, &kernels, kernel_names, kernel_ctx); |
218 | CHECK(status); |
219 | |
220 | bwd_kernel_ = kernels[0]; |
221 | calculate_stats_kernel_ = kernels[1]; |
222 | reduce_stats_kernel_ = kernels[2]; |
223 | reduce_init_kernel_ = kernels[3]; |
224 | reduce_final_kernel_ = kernels[4]; |
225 | |
226 | return status::success; |
227 | } |
228 | |
229 | status_t execute(const exec_ctx_t &ctx) const override { |
230 | return execute_backward(ctx); |
231 | } |
232 | |
233 | private: |
234 | status_t execute_backward(const exec_ctx_t &ctx) const; |
235 | const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } |
236 | compute::kernel_t bwd_kernel_; |
237 | compute::kernel_t calculate_stats_kernel_; |
238 | compute::kernel_t reduce_stats_kernel_; |
239 | compute::kernel_t reduce_init_kernel_; |
240 | compute::kernel_t reduce_final_kernel_; |
241 | }; |
242 | |
243 | } // namespace ocl |
244 | } // namespace gpu |
245 | } // namespace impl |
246 | } // namespace dnnl |
247 | |
248 | #endif // GPU_OCL_GEN9_BATCH_NORMALIZATION_HPP |
249 | |