/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_SIMPLE_LAYER_NORMALIZATION_HPP
#define CPU_SIMPLE_LAYER_NORMALIZATION_HPP

#include <memory>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/reorder_pd.hpp"
#include "common/stream.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_layer_normalization_pd.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

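/* Reference semantics of forward layer normalization (per the oneDNN spec):
 * with mean/variance computed over the last logical axis,
 *
 *   dst(t, c) = gamma(c) * (src(t, c) - mean(t)) / sqrt(variance(t) + eps)
 *             + beta(c)
 *
 * where t indexes statistics points (all axes but the last), c indexes the
 * normalized axis, and gamma/beta are the optional scale/shift arguments. */
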
struct simple_layer_normalization_fwd_t : public primitive_t {
    struct pd_t : public cpu_layer_normalization_fwd_pd_t {
        using cpu_layer_normalization_fwd_pd_t::
                cpu_layer_normalization_fwd_pd_t;

        DECLARE_COMMON_PD_T("simple:any", simple_layer_normalization_fwd_t);

        status_t init(engine_t *engine);

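        // Temporary statistics buffers are needed when the user's statistics
        // must be reordered (reorder_pd_ is set) or when the statistics exist
        // only in scratchpad (stats_are_tmp()).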
        bool use_tmp_stats() const { return reorder_pd_ || stats_are_tmp(); }

        std::shared_ptr<primitive_desc_t> reorder_pd_;
        memory_desc_t reordered_stat_md_;

    private:
        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
            if (use_tmp_stats()) {
                scratchpad.template book<float>(
                        key_lnorm_tmp_mean, across_axis());
                scratchpad.template book<float>(
                        key_lnorm_tmp_var, across_axis());
            }
            if (reordered_stat_md_ != *stat_md() && !stats_are_tmp()) {
                scratchpad.book(key_nested, reorder_pd_->scratchpad_registry());
            }
        }
    };

    status_t init(engine_t *engine) override {
        if (pd()->reorder_pd_)
            CHECK(pd()->reorder_pd_->create_primitive(reorder_, engine));
        return status::success;
    }

    simple_layer_normalization_fwd_t(const pd_t *apd) : primitive_t(apd) {}

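    // Runs the nested reorder primitive on a single statistics tensor,
    // granting it a dedicated slice of this primitive's scratchpad.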
    void reorder_stat(const exec_ctx_t &ctx, engine_t *engine,
            const memory_arg_t &in, const memory_arg_t &out) const {
        using namespace memory_tracking::names;
        exec_args_t r_args;
        r_args[DNNL_ARG_SRC] = in;
        r_args[DNNL_ARG_DST] = out;
        exec_ctx_t r_ctx(ctx, std::move(r_args));

        nested_scratchpad_t ns(ctx, key_nested, reorder_);
        r_ctx.set_scratchpad_grantor(ns.grantor());
        reorder_->execute(r_ctx);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        /* LN supports arbitrary layouts for the input/output statistics.
         * For best performance we compute LN with the statistics in the same
         * format as the data tensor (i.e. data in abcd, stats in abc); the
         * user's input/output statistics are reordered if necessary. */
        using namespace memory_tracking::names;
        engine_t *engine = ctx.stream()->engine();
        auto scratchpad = ctx.get_scratchpad_grantor();
        auto mean_mem = scratchpad.get_memory_storage(key_lnorm_tmp_mean);
        auto variance_mem = scratchpad.get_memory_storage(key_lnorm_tmp_var);
        memory_t mean(engine, &(pd()->reordered_stat_md_), std::move(mean_mem));
        memory_t variance(
                engine, &(pd()->reordered_stat_md_), std::move(variance_mem));

        // reorder input stats
        if (pd()->stats_are_src() && reorder_) {
            reorder_stat(
                    ctx, engine, ctx.args().at(DNNL_ARG_MEAN), {&mean, false});
            reorder_stat(ctx, engine, ctx.args().at(DNNL_ARG_VARIANCE),
                    {&variance, false});
        }
        status_t status = execute_forward(ctx);
        if (status != status::success) return status;
        // reorder output stats
        if (!pd()->stats_are_src() && reorder_) {
            reorder_stat(
                    ctx, engine, {&mean, true}, ctx.args().at(DNNL_ARG_MEAN));
            reorder_stat(ctx, engine, {&variance, true},
                    ctx.args().at(DNNL_ARG_VARIANCE));
        }

        return status::success;
    }

private:
    status_t execute_forward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::shared_ptr<primitive_t> reorder_;
};

struct simple_layer_normalization_bwd_t : public primitive_t {
    struct pd_t : public cpu_layer_normalization_bwd_pd_t {
        using cpu_layer_normalization_bwd_pd_t::
                cpu_layer_normalization_bwd_pd_t;

        DECLARE_COMMON_PD_T("simple:any", simple_layer_normalization_bwd_t);

        status_t init(engine_t *engine);

        bool use_tmp_stats() const { return reorder_pd_.get(); }

        std::shared_ptr<primitive_desc_t> reorder_pd_;
        memory_desc_t reordered_stat_md_;
        int nthr_; // Thread count fixed at pd creation time; execute()
                   // must not exceed it.

    private:
        void init_scratchpad() {
            using namespace memory_tracking::names;
            auto scratchpad = scratchpad_registry().registrar();
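            // Sizes are in floats: temporary mean/variance buffers (when
            // statistics must be reordered), per-thread partial
            // diff_gamma/diff_beta reductions (2 * norm_axis() per thread),
            // the reduced diff scale/shift pair (2 * norm_axis()), and one
            // cached 1/sqrt(variance + eps) per statistics point.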
            if (use_tmp_stats()) {
                scratchpad.template book<float>(
                        key_lnorm_tmp_mean, across_axis());
                scratchpad.template book<float>(
                        key_lnorm_tmp_var, across_axis());
            }
            scratchpad.template book<float>(
                    key_lnorm_reduction, 2 * norm_axis() * nthr_);
            scratchpad.template book<float>(
                    key_lnorm_tmp_diff_ss, 2 * norm_axis());
            if (reordered_stat_md_ != *stat_md() && !stats_are_tmp()) {
                scratchpad.book(key_nested, reorder_pd_->scratchpad_registry());
            }
            scratchpad.template book<float>(
                    key_lnorm_inv_sqrtvar, across_axis());
        }
    };

    status_t init(engine_t *engine) override {
        if (pd()->reorder_pd_)
            CHECK(pd()->reorder_pd_->create_primitive(reorder_, engine));
        return status::success;
    }

    simple_layer_normalization_bwd_t(const pd_t *apd) : primitive_t(apd) {}

    void reorder_stat(const exec_ctx_t &ctx, engine_t *engine,
            const memory_arg_t &in, const memory_arg_t &out) const {
        using namespace memory_tracking::names;
        exec_args_t r_args;
        r_args[DNNL_ARG_SRC] = in;
        r_args[DNNL_ARG_DST] = out;
        exec_ctx_t r_ctx(ctx, std::move(r_args));

        nested_scratchpad_t ns(ctx, key_nested, reorder_);
        r_ctx.set_scratchpad_grantor(ns.grantor());
        reorder_->execute(r_ctx);
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        using namespace memory_tracking::names;
        /* LN supports arbitrary layouts for the input/output statistics.
         * For best performance we compute LN with the statistics in the same
         * format as the data tensor (i.e. data in abcd, stats in abc); the
         * user's input/output statistics are reordered if necessary. */

        if (reorder_) {
            engine_t *engine = ctx.stream()->engine();
            auto scratchpad = ctx.get_scratchpad_grantor();
            auto mean_mem = scratchpad.get_memory_storage(key_lnorm_tmp_mean);
            auto variance_mem
                    = scratchpad.get_memory_storage(key_lnorm_tmp_var);
            memory_t mean(
                    engine, &(pd()->reordered_stat_md_), std::move(mean_mem));
            memory_t variance(engine, &(pd()->reordered_stat_md_),
                    std::move(variance_mem));
            reorder_stat(
                    ctx, engine, ctx.args().at(DNNL_ARG_MEAN), {&mean, false});
            reorder_stat(ctx, engine, ctx.args().at(DNNL_ARG_VARIANCE),
                    {&variance, false});
        }

        return execute_backward(ctx);
    }

private:
    status_t execute_backward(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::shared_ptr<primitive_t> reorder_;
};
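
/* A minimal usage sketch, assuming the public oneDNN v2.x C++ API from
 * dnnl.hpp (names below belong to that API, not to this header):
 *
 *   using namespace dnnl;
 *   auto ln_d = layer_normalization_forward::desc(
 *           prop_kind::forward_training, data_md, epsilon,
 *           normalization_flags::use_scale_shift);
 *   auto ln_pd = layer_normalization_forward::primitive_desc(ln_d, eng);
 *   layer_normalization_forward(ln_pd).execute(strm,
 *           {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem},
 *                   {DNNL_ARG_SCALE_SHIFT, scale_shift_mem},
 *                   {DNNL_ARG_MEAN, mean_mem},
 *                   {DNNL_ARG_VARIANCE, variance_mem}});
 *
 * With forward_training, mean/variance are outputs; the CPU runtime may then
 * dispatch to this implementation (registered as "simple:any"). */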

} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif

// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s