/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_REF_BATCH_NORMALIZATION_HPP
#define GPU_OCL_REF_BATCH_NORMALIZATION_HPP

#include <assert.h>

#include "common/c_types_map.hpp"
#include "common/primitive.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gpu_batch_normalization_pd.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_resource.hpp"
#include "gpu/ocl/ocl_stream.hpp"
#include "gpu/ocl/ocl_utils.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
struct ref_batch_normalization_fwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_batch_normalization_fwd_pd_t {
        pd_t(const batch_normalization_desc_t *adesc,
                const primitive_attr_t *attr,
                const batch_normalization_fwd_pd_t *hint_fwd_pd)
            : gpu_batch_normalization_fwd_pd_t(adesc, attr, hint_fwd_pd) {}

        DECLARE_COMMON_PD_T("ocl:ref:any", ref_batch_normalization_fwd_t);

        status_t init(engine_t *engine) {
            using namespace data_type;
            auto *compute_engine
                    = utils::downcast<compute::compute_engine_t *>(engine);

            const auto attr_skip_mask = primitive_attr_t::skip_mask_t::post_ops;

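            // Dispatch checks: forward propagation over f32/bf16/f16/s8 with
            // matching src/dst data types and memory layouts. int8 is
            // inference-only and requires user-provided statistics; the only
            // accepted attribute is a single ReLU post-op; the device must
            // support the intel_subgroups extension.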
            bool ok = is_fwd()
                    && utils::one_of(src_md()->data_type, f32, bf16, f16, s8)
                    && src_md()->data_type == dst_md()->data_type
                    && IMPLICATION(src_md()->data_type == s8,
                            !is_training() && stats_is_src())
                    && check_scale_shift_data_type()
                    && attr()->has_default_values(attr_skip_mask)
                    && IMPLICATION(!attr()->has_default_values(),
                            attr()->post_ops_.len() == 1 && with_relu_post_op())
                    && set_default_formats_common()
                    && memory_desc_wrapper(src_md())
                            == memory_desc_wrapper(dst_md())
                    && compute_engine->mayiuse(
                            compute::device_ext_t::intel_subgroups);
            if (!ok) return status::unimplemented;

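            // Training with fused (add-)ReLU saves the activation mask in a
            // workspace for the backward pass; the argument is the number of
            // bits per workspace element.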
            if (is_training() && (fuse_norm_relu() || fuse_norm_add_relu()))
                init_default_ws(8);

            status_t status = init_conf(engine);
            if (status != status::success) return status;
            init_scratchpad();

            return status::success;
        }

        status_t init_conf(engine_t *engine);
        status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
        void init_scratchpad();

        bnorm_conf_t conf;
        offsets_t off;
    };

    status_t init(engine_t *engine) override {
        compute::kernel_ctx_t kernel_ctx;

        status_t status = pd()->init_kernel_ctx(kernel_ctx);
        CHECK(status);

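        // Slot 0 is the main forward kernel. Slots 1-4 are filled only when
        // statistics are computed here (i.e. training without user-provided
        // stats); slot 5 holds the fused single-pass mean+variance variant.
        // nullptr entries are skipped, leaving the corresponding kernel_t
        // objects empty.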
        std::vector<const char *> kernel_names = {
                "ref_bnorm_fwd", nullptr, nullptr, nullptr, nullptr, nullptr};
        if (pd()->conf.calculate_stats) {
            kernel_names[1] = "calculate_mean";
            kernel_names[2] = "calculate_variance";
            kernel_names[3] = "reduce_mean";
            kernel_names[4] = "reduce_variance";
        }

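        // When the configuration allows skipping the separate reduction
        // step, mean and variance are computed together by a single fused
        // kernel instead of the four kernels above.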
        if (pd()->conf.skip_reduce_stat) {
            kernel_names[5] = "calculate_mean_variance";
        }

        std::vector<compute::kernel_t> kernels;
        status = create_kernels(engine, &kernels, kernel_names, kernel_ctx);
        CHECK(status);

        kernel_ = kernels[0];
        calculate_mean_kernel_ = kernels[1];
        calculate_variance_kernel_ = kernels[2];
        reduce_mean_kernel_ = kernels[3];
        reduce_variance_kernel_ = kernels[4];
        calculate_mean_variance_kernel_ = kernels[5];

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    compute::kernel_t kernel_;
    compute::kernel_t calculate_mean_kernel_;
    compute::kernel_t reduce_mean_kernel_;
    compute::kernel_t calculate_variance_kernel_;
    compute::kernel_t reduce_variance_kernel_;
    compute::kernel_t calculate_mean_variance_kernel_;
};

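// Reference OpenCL implementation of batch normalization, backward
// propagation.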
struct ref_batch_normalization_bwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_batch_normalization_bwd_pd_t {
        pd_t(const batch_normalization_desc_t *adesc,
                const primitive_attr_t *attr,
                const batch_normalization_fwd_pd_t *hint_fwd_pd)
            : gpu_batch_normalization_bwd_pd_t(adesc, attr, hint_fwd_pd) {}

        DECLARE_COMMON_PD_T("ocl:ref:any", ref_batch_normalization_bwd_t);

        status_t init(engine_t *engine) {
            using namespace data_type;

            bool ok = !is_fwd() && utils::one_of(src_md()->data_type, f32, bf16)
                    && src_md()->data_type == diff_src_md()->data_type
                    && diff_src_md()->data_type == diff_dst_md()->data_type
                    && check_scale_shift_data_type()
                    && attr()->has_default_values()
                    && set_default_formats_common()
                    && memory_desc_wrapper(diff_src_md())
                            == memory_desc_wrapper(diff_dst_md());
            if (!ok) return status::unimplemented;

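            // Fused (add-)ReLU backward consumes the activation mask written
            // by the forward pass, so the workspace layout must match the one
            // set up by the forward primitive descriptor used as a hint.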
            if (fuse_norm_relu() || fuse_norm_add_relu()) {
                init_default_ws(8);
                if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
            }

            status_t status = init_conf(engine);
            if (status != status::success) return status;
            init_scratchpad();

            return status::success;
        }

        status_t init_conf(engine_t *engine);
        status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
        void init_scratchpad();

        bnorm_conf_t conf;
        offsets_t off;
    };

    status_t init(engine_t *engine) override {
        compute::kernel_ctx_t kernel_ctx;

        status_t status = pd()->init_kernel_ctx(kernel_ctx);
        CHECK(status);

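        // Backward always builds the statistics reduction pair:
        // "calculate_stats" produces per-chunk partial sums and
        // "reduce_stats" folds them into the final values consumed by the
        // main backward kernel.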
        std::vector<const char *> kernel_names
                = {"ref_bnorm_bwd", "calculate_stats", "reduce_stats"};

        std::vector<compute::kernel_t> kernels;
        status = create_kernels(engine, &kernels, kernel_names, kernel_ctx);
        CHECK(status);

        kernel_ = kernels[0];
        calculate_stats_kernel_ = kernels[1];
        reduce_stats_kernel_ = kernels[2];

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_backward(ctx);
    }

private:
    status_t execute_backward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
    compute::kernel_t kernel_;
    compute::kernel_t calculate_stats_kernel_;
    compute::kernel_t reduce_stats_kernel_;
};

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif