1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_OCL_GEN9_BATCH_NORMALIZATION_HPP
18#define GPU_OCL_GEN9_BATCH_NORMALIZATION_HPP
19
20#include "common/experimental.hpp"
21#include "common/primitive.hpp"
22#include "gpu/compute/compute.hpp"
23#include "gpu/gpu_batch_normalization_pd.hpp"
24#include "gpu/gpu_primitive.hpp"
25#include "gpu/gpu_resource.hpp"
26#include "gpu/primitive_conf.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace gpu {
31namespace ocl {
32
33struct gen9_batch_normalization_fwd_t : public gpu_primitive_t {
34 using gpu_primitive_t::gpu_primitive_t;
35 struct pd_t : public gpu_batch_normalization_fwd_pd_t {
36 pd_t(const batch_normalization_desc_t *adesc,
37 const primitive_attr_t *attr,
38 const batch_normalization_fwd_pd_t *hint_fwd_pd)
39 : gpu_batch_normalization_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
40
41 DECLARE_COMMON_PD_T(impl_name(), gen9_batch_normalization_fwd_t);
42
43 const char *impl_name() const {
44 return conf.nhwc_optimized
45 ? (conf.use_stats_one_pass ? "ocl:gen9:nhwc:onepass"
46 : "ocl:gen9:nhwc")
47 : (conf.use_stats_one_pass ? "ocl:gen9:blocked:onepass"
48 : "ocl:gen9:blocked");
49 }
50 status_t init(engine_t *engine) {
51 using namespace data_type;
52 auto *compute_engine
53 = utils::downcast<compute::compute_engine_t *>(engine);
54
55 const auto attr_skip_mask = primitive_attr_t::skip_mask_t::post_ops;
56
57 bool ok = is_fwd()
58 && utils::one_of(src_md()->data_type, f32, bf16, f16, s8)
59 && src_md()->data_type == dst_md()->data_type
60 && IMPLICATION(src_md()->data_type == s8,
61 !is_training() && stats_is_src())
62 && check_scale_shift_data_type()
63 && attr()->has_default_values(attr_skip_mask)
64 && IMPLICATION(!attr()->has_default_values(),
65 attr()->post_ops_.len() == 1 && with_relu_post_op())
66 && set_default_formats_common()
67 && memory_desc_wrapper(src_md())
68 == memory_desc_wrapper(dst_md())
69 && compute_engine->mayiuse(
70 compute::device_ext_t::intel_subgroups);
71 if (!ok) return status::unimplemented;
72
73 if (is_training() && (fuse_norm_relu() || fuse_norm_add_relu()))
74 init_default_ws(8);
75
76 status_t status = init_conf(engine);
77 if (status != status::success) return status;
78 init_scratchpad();
79
80 return status::success;
81 }
82
83 status_t init_conf(engine_t *engine);
84 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
85 void init_scratchpad();
86
87 bnorm_conf_t conf;
88 offsets_t off;
89 };
90
91 status_t init(engine_t *engine) override {
92 compute::kernel_ctx_t kernel_ctx;
93
94 status_t status = pd()->init_kernel_ctx(kernel_ctx);
95 CHECK(status);
96
97 std::vector<const char *> kernel_names = {"gen9_bnorm_fwd", nullptr,
98 nullptr, nullptr, nullptr, nullptr, nullptr};
99 if (pd()->conf.calculate_stats) {
100 if (pd()->conf.use_stats_one_pass) {
101 kernel_names[1] = "gen9_calc_mean_var";
102 kernel_names[2] = "gen9_reduce_mean_var";
103 kernel_names[3] = "gen9_fused_reduce_init";
104 kernel_names[4] = "gen9_fused_reduce_final";
105 } else {
106 kernel_names[1] = "gen9_calc_mean";
107 kernel_names[2] = "gen9_calc_variance";
108 kernel_names[3] = "gen9_reduce_mean";
109 kernel_names[4] = "gen9_reduce_variance";
110 kernel_names[5] = "gen9_fused_reduce_init";
111 kernel_names[6] = "gen9_fused_reduce_final";
112 }
113 }
114
115 std::vector<compute::kernel_t> kernels;
116 status = create_kernels(engine, &kernels, kernel_names, kernel_ctx);
117 CHECK(status);
118
119 kernel_ = kernels[0];
120 if (pd()->conf.use_stats_one_pass) {
121 calculate_mean_var_kernel_ = kernels[1];
122 reduce_mean_var_kernel_ = kernels[2];
123 reduce_init_kernel_ = kernels[3];
124 reduce_final_kernel_ = kernels[4];
125 } else {
126 calculate_mean_kernel_ = kernels[1];
127 calculate_variance_kernel_ = kernels[2];
128 reduce_mean_kernel_ = kernels[3];
129 reduce_variance_kernel_ = kernels[4];
130 reduce_init_kernel_ = kernels[5];
131 reduce_final_kernel_ = kernels[6];
132 }
133
134 return status::success;
135 }
136
137 status_t execute(const exec_ctx_t &ctx) const override {
138 return execute_forward(ctx);
139 }
140
141private:
142 status_t execute_forward(const exec_ctx_t &ctx) const;
143 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
144 compute::kernel_t kernel_;
145 compute::kernel_t calculate_mean_kernel_;
146 compute::kernel_t reduce_mean_kernel_;
147 compute::kernel_t calculate_variance_kernel_;
148 compute::kernel_t reduce_variance_kernel_;
149 compute::kernel_t calculate_mean_var_kernel_;
150 compute::kernel_t reduce_mean_var_kernel_;
151 compute::kernel_t reduce_init_kernel_;
152 compute::kernel_t reduce_final_kernel_;
153};
154
155struct gen9_batch_normalization_bwd_t : public gpu_primitive_t {
156 using gpu_primitive_t::gpu_primitive_t;
157 struct pd_t : public gpu_batch_normalization_bwd_pd_t {
158 pd_t(const batch_normalization_desc_t *adesc,
159 const primitive_attr_t *attr,
160 const batch_normalization_fwd_pd_t *hint_fwd_pd)
161 : gpu_batch_normalization_bwd_pd_t(adesc, attr, hint_fwd_pd) {}
162
163 DECLARE_COMMON_PD_T(impl_name(), gen9_batch_normalization_bwd_t);
164
165 const char *impl_name() const {
166 return conf.nhwc_optimized ? "ocl:gen9:nhwc" : "ocl:gen9:blocked";
167 }
168
169 status_t init(engine_t *engine) {
170 auto *compute_engine
171 = utils::downcast<compute::compute_engine_t *>(engine);
172 using namespace data_type;
173
174 bool ok = !is_fwd() && utils::one_of(src_md()->data_type, f32, bf16)
175 && src_md()->data_type == diff_src_md()->data_type
176 && diff_src_md()->data_type == diff_dst_md()->data_type
177 && check_scale_shift_data_type()
178 && attr()->has_default_values()
179 && set_default_formats_common()
180 && memory_desc_wrapper(diff_src_md())
181 == memory_desc_wrapper(diff_dst_md())
182 && compute_engine->mayiuse(
183 compute::device_ext_t::intel_subgroups);
184 if (!ok) return status::unimplemented;
185
186 if (fuse_norm_relu() || fuse_norm_add_relu()) {
187 init_default_ws(8);
188 if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
189 }
190
191 status_t status = init_conf(engine);
192 if (status != status::success) return status;
193 init_scratchpad();
194
195 return status::success;
196 }
197
198 status_t init_conf(engine_t *engine);
199 status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
200 void init_scratchpad();
201
202 bnorm_conf_t conf;
203 offsets_t off;
204 };
205
206 status_t init(engine_t *engine) override {
207 compute::kernel_ctx_t kernel_ctx;
208
209 status_t status = pd()->init_kernel_ctx(kernel_ctx);
210 CHECK(status);
211
212 std::vector<const char *> kernel_names = {"gen9_bnorm_bwd",
213 "gen9_calculate_stats", "gen9_reduce_stats",
214 "gen9_fused_reduce_init", "gen9_fused_reduce_final"};
215
216 std::vector<compute::kernel_t> kernels;
217 status = create_kernels(engine, &kernels, kernel_names, kernel_ctx);
218 CHECK(status);
219
220 bwd_kernel_ = kernels[0];
221 calculate_stats_kernel_ = kernels[1];
222 reduce_stats_kernel_ = kernels[2];
223 reduce_init_kernel_ = kernels[3];
224 reduce_final_kernel_ = kernels[4];
225
226 return status::success;
227 }
228
229 status_t execute(const exec_ctx_t &ctx) const override {
230 return execute_backward(ctx);
231 }
232
233private:
234 status_t execute_backward(const exec_ctx_t &ctx) const;
235 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
236 compute::kernel_t bwd_kernel_;
237 compute::kernel_t calculate_stats_kernel_;
238 compute::kernel_t reduce_stats_kernel_;
239 compute::kernel_t reduce_init_kernel_;
240 compute::kernel_t reduce_final_kernel_;
241};
242
243} // namespace ocl
244} // namespace gpu
245} // namespace impl
246} // namespace dnnl
247
248#endif // GPU_OCL_GEN9_BATCH_NORMALIZATION_HPP
249