/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_OCL_REF_LAYER_NORMALIZATION_HPP
#define GPU_OCL_REF_LAYER_NORMALIZATION_HPP

#include "common/c_types_map.hpp"
#include "common/primitive.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "gpu/compute/compute.hpp"
#include "gpu/gpu_layer_normalization_pd.hpp"
#include "gpu/gpu_primitive.hpp"
#include "gpu/gpu_resource.hpp"
#include "gpu/primitive_conf.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

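// Reference OpenCL implementation of layer normalization forward propagation.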
struct ref_layer_normalization_fwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_layer_normalization_fwd_pd_t {
        using gpu_layer_normalization_fwd_pd_t::
                gpu_layer_normalization_fwd_pd_t;

        DECLARE_COMMON_PD_T("lnorm_ref:any", ref_layer_normalization_fwd_t);

        status_t init(engine_t *engine) {
            using namespace data_type;

            auto src_data_t = src_md()->data_type;
            auto dst_data_t = dst_md()->data_type;

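            // Supported configurations: matching f16, bf16, or f32 source
            // and destination types, f32 statistics, and default attributes.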
            bool ok = is_fwd()
                    && (utils::everyone_is(f16, src_data_t, dst_data_t)
                            || utils::everyone_is(bf16, src_data_t, dst_data_t)
                            || utils::everyone_is(f32, src_data_t, dst_data_t))
                    && !memory_desc_ndims_ok(src_md(), dst_md(), stat_md())
                    && stat_md()->data_type == f32
                    && check_scale_shift_data_type()
                    && attr()->has_default_values()
                    && set_default_formats_common();
            if (!ok) return status::unimplemented;

            return init_conf(engine);
        }

        status_t init_conf(engine_t *engine);
        status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;

        lnorm_conf_t conf;
    };

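    // Compiles the forward reference kernel; there is nothing to compile
    // when any memory descriptor has a zero dimension.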
    status_t init(engine_t *engine) override {
        if (pd()->has_zero_dim_memory()) return status::success;

        compute::kernel_ctx_t kernel_ctx;

        status_t status = pd()->init_kernel_ctx(kernel_ctx);
        CHECK(status);

        create_kernel(engine, &kernel_, "ref_lnorm_fwd", kernel_ctx);
        if (!kernel_) return status::runtime_error;

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_forward(ctx);
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    compute::kernel_t kernel_;
};

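// Reference OpenCL implementation of layer normalization backward propagation.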
struct ref_layer_normalization_bwd_t : public gpu_primitive_t {
    using gpu_primitive_t::gpu_primitive_t;
    struct pd_t : public gpu_layer_normalization_bwd_pd_t {
        using gpu_layer_normalization_bwd_pd_t::
                gpu_layer_normalization_bwd_pd_t;

        DECLARE_COMMON_PD_T("lnorm_ref:any", ref_layer_normalization_bwd_t);

        status_t init(engine_t *engine) {
            using namespace data_type;

            auto src_dt = src_md()->data_type;
            auto diff_dst_dt = diff_dst_md()->data_type;
            auto diff_src_dt = diff_src_md()->data_type;

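            // Supported configurations: matching f32 or bf16 source and
            // gradient types, f32 statistics, and default attributes.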
            bool ok = is_bwd()
                    && (utils::everyone_is(
                                f32, src_dt, diff_dst_dt, diff_src_dt)
                            || utils::everyone_is(
                                    bf16, src_dt, diff_dst_dt, diff_src_dt))
                    && stat_md()->data_type == f32
                    && check_scale_shift_data_type()
                    && attr()->has_default_values()
                    && set_default_formats_common();
            if (!ok) return status::unimplemented;

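            // The vectorized scale/shift path needs scratchpad memory;
            // reserve it when that path is selected.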
            CHECK(init_conf(engine));
            if (conf.vectorize_bwd_scaleshift) { init_scratchpad(); }
            return status::success;
        }

        status_t init_conf(engine_t *engine);
        status_t init_kernel_ctx(compute::kernel_ctx_t &kernel_ctx) const;
        void init_scratchpad();

        lnorm_conf_t conf;
    };

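    // Compiles the backward kernels: the main kernel, a scale/shift kernel
    // when scale or shift gradients are requested, and a finalization kernel
    // for the vectorized scale/shift path.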
    status_t init(engine_t *engine) override {
        if (pd()->has_zero_dim_memory()) return status::success;

        compute::kernel_ctx_t kernel_ctx;

        status_t status = pd()->init_kernel_ctx(kernel_ctx);
        CHECK(status);

        create_kernel(engine, &kernel_, "ref_lnorm_bwd", kernel_ctx);
        if (pd()->conf.use_scale || pd()->conf.use_shift) {
            create_kernel(engine, &kernel_scaleshift_,
                    "ref_lnorm_bwd_scaleshift", kernel_ctx);
            if (!kernel_scaleshift_) return status::runtime_error;
            if (pd()->conf.vectorize_bwd_scaleshift) {
                create_kernel(engine, &kernel_scaleshift_finalize_,
                        "ref_lnorm_bwd_scaleshift_final", kernel_ctx);
                if (!kernel_scaleshift_finalize_) return status::runtime_error;
            }
        }
        if (!kernel_) return status::runtime_error;

        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        return execute_backward(ctx);
    }

private:
    status_t execute_backward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    compute::kernel_t kernel_scaleshift_;
    compute::kernel_t kernel_scaleshift_finalize_;
    compute::kernel_t kernel_;
};

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif