1/*******************************************************************************
2* Copyright 2020-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_LRN_JIT_LRN_AVX512_BLOCKED_EXECUTOR_HPP
18#define CPU_X64_LRN_JIT_LRN_AVX512_BLOCKED_EXECUTOR_HPP
19
20#include "cpu/x64/lrn/jit_avx512_common_lrn_bwd_blocked.hpp"
21#include "cpu/x64/lrn/jit_avx512_common_lrn_fwd_blocked.hpp"
22#include "cpu/x64/lrn/lrn_executor.hpp"
23
24namespace dnnl {
25namespace impl {
26namespace cpu {
27namespace x64 {
28namespace lrn {
29
30template <::dnnl::impl::data_type_t d_type, typename PD_T>
31class lrn_avx512_blocked_executor_fwd_t : public i_lrn_executor_t {
32public:
33 lrn_avx512_blocked_executor_fwd_t(const PD_T *pd)
34 : ker_(nullptr)
35 , ker_first_(nullptr)
36 , ker_last_(nullptr)
37 , N_(pd->MB())
38 , C_(pd->C())
39 , H_(pd->H())
40 , W_(pd->W())
41 , use_h_parallelism_(H_ > 28 ? 1 : 0) {
42
43 const int local_size = pd->desc()->local_size;
44 const float alpha = pd->desc()->lrn_alpha / local_size;
45 const float beta = pd->desc()->lrn_beta;
46 const auto pk = pd->desc()->prop_kind;
47 const float k = pd->desc()->lrn_k;
48
49 if (C_ / vsize_ == 1) {
50 ker_ = utils::make_unique<
51 lrn::jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>>(
52 lrn::nChw16c_across_t(H_, W_, lrn::across_version::Single),
53 pk, use_h_parallelism_, alpha, beta, k, local_size);
54 } else {
55 ker_ = utils::make_unique<
56 lrn::jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>>(
57 lrn::nChw16c_across_t(H_, W_, lrn::across_version::Middle),
58 pk, use_h_parallelism_, alpha, beta, k, local_size);
59 ker_first_ = utils::make_unique<
60 lrn::jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>>(
61 lrn::nChw16c_across_t(H_, W_, lrn::across_version::First),
62 pk, use_h_parallelism_, alpha, beta, k, local_size);
63 ker_last_ = utils::make_unique<
64 lrn::jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>>(
65 lrn::nChw16c_across_t(H_, W_, lrn::across_version::Last),
66 pk, use_h_parallelism_, alpha, beta, k, local_size);
67 }
68 }
69
70 using data_t = typename prec_traits<d_type>::type;
71
72 status_t create_kernel() override {
73 CHECK(ker_->create_kernel());
74 if (ker_first_) CHECK(ker_first_->create_kernel());
75 if (ker_last_) CHECK(ker_last_->create_kernel());
76 return status::success;
77 }
78
79 status_t execute(const exec_ctx_t &ctx) const override {
80 status_t status = status::success;
81 const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
82 const auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
83 CHECK(status);
84 const auto ws = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_WORKSPACE, status);
85 CHECK(status);
86
87 const auto ker = ker_.get();
88 const auto ker_first = ker_first_.get();
89 const auto ker_last = ker_last_.get();
90
91 parallel(0, [&](const int ithr, const int nthr) {
92 size_t start {0}, end {0};
93 const int C16 = C_ / vsize_;
94 const size_t work_amount
95 = use_h_parallelism_ ? N_ * C16 * H_ : N_ * C16;
96
97 balance211(work_amount, nthr, ithr, start, end);
98 if (use_h_parallelism_) {
99 int n {0}, c16 {0}, h {0};
100 nd_iterator_init(start, n, N_, c16, C16, h, H_);
101 for (size_t iwork = start; iwork < end; ++iwork) {
102 const auto offset = n * C_ * H_ * W_
103 + c16 * H_ * W_ * vsize_ + h * W_ * vsize_;
104 const auto ws_offset0 = n * C_ * H_ * 2 * W_
105 + c16 * H_ * 2 * W_ * vsize_ + h * 2 * W_ * vsize_;
106 const auto ws_offset1 = ws_offset0 + W_ * vsize_;
107
108 typename lrn::jit_avx512_common_lrn_kernel_fwd_t<
109 d_type>::jit_args_fwd_t args;
110 args.src = &src[offset];
111 args.dst = &dst[offset];
112 args.ws0 = ws ? &ws[ws_offset0] : nullptr;
113 args.ws1 = ws ? &ws[ws_offset1] : nullptr;
114
115 if (C16 == 1)
116 (*ker)(&args);
117 else if (c16 == 0)
118 (*ker_first)(&args);
119 else if (c16 == C16 - 1)
120 (*ker_last)(&args);
121 else
122 (*ker)(&args);
123 nd_iterator_step(n, N_, c16, C16, h, H_);
124 }
125 } else {
126 int n {0}, c16 {0};
127 nd_iterator_init(start, n, N_, c16, C16);
128 for (size_t iwork = start; iwork < end; ++iwork) {
129 const auto offset
130 = n * C_ * H_ * W_ + c16 * H_ * W_ * vsize_;
131 const auto ws_offset0
132 = n * C_ * H_ * 2 * W_ + c16 * H_ * 2 * W_ * vsize_;
133 const auto ws_offset1 = ws_offset0 + H_ * W_ * vsize_;
134
135 typename lrn::jit_avx512_common_lrn_kernel_fwd_t<
136 d_type>::jit_args_fwd_t args;
137 args.src = &src[offset];
138 args.dst = &dst[offset];
139 args.ws0 = ws ? &ws[ws_offset0] : nullptr;
140 args.ws1 = ws ? &ws[ws_offset1] : nullptr;
141
142 if (C16 == 1)
143 (*ker)(&args);
144 else if (c16 == 0)
145 (*ker_first)(&args);
146 else if (c16 == C16 - 1)
147 (*ker_last)(&args);
148 else
149 (*ker)(&args);
150
151 nd_iterator_step(n, N_, c16, C16);
152 }
153 }
154 });
155
156 return status::success;
157 }
158
159private:
160 std::unique_ptr<lrn::jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>>
161 ker_, ker_first_, ker_last_;
162 static constexpr int vsize_ = 16;
163 const int N_;
164 const int C_;
165 const int H_;
166 const int W_;
167 const int use_h_parallelism_;
168};
169
170template <::dnnl::impl::data_type_t d_type, typename PD_T>
171class lrn_avx512_blocked_executor_bwd_t : public i_lrn_executor_t {
172public:
173 lrn_avx512_blocked_executor_bwd_t(const PD_T *pd)
174 : ker_(nullptr)
175 , ker_first_(nullptr)
176 , ker_last_(nullptr)
177 , N_(pd->MB())
178 , C_(pd->C())
179 , H_(pd->H())
180 , W_(pd->W())
181 , use_h_parallelism_(H_ > 28 ? 1 : 0) {
182
183 const int local_size = pd->desc()->local_size;
184 const float alpha = pd->desc()->lrn_alpha / local_size;
185 const float beta = pd->desc()->lrn_beta;
186
187 if (C_ / vsize_ == 1) {
188 ker_ = utils::make_unique<
189 lrn::jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>>(
190 lrn::nChw16c_across_t(H_, W_, lrn::across_version::Single),
191 alpha, beta, local_size, use_h_parallelism_);
192 } else {
193 ker_ = utils::make_unique<
194 lrn::jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>>(
195 lrn::nChw16c_across_t(H_, W_, lrn::across_version::Middle),
196 alpha, beta, local_size, use_h_parallelism_);
197 ker_first_ = utils::make_unique<
198 lrn::jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>>(
199 lrn::nChw16c_across_t(H_, W_, lrn::across_version::First),
200 alpha, beta, local_size, use_h_parallelism_);
201 ker_last_ = utils::make_unique<
202 lrn::jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>>(
203 lrn::nChw16c_across_t(H_, W_, lrn::across_version::Last),
204 alpha, beta, local_size, use_h_parallelism_);
205 }
206 }
207
208 using data_t = typename prec_traits<d_type>::type;
209
210 status_t create_kernel() override {
211 CHECK(ker_->create_kernel());
212 if (ker_first_) CHECK(ker_first_->create_kernel());
213 if (ker_last_) CHECK(ker_last_->create_kernel());
214 return status::success;
215 }
216
217 status_t execute(const exec_ctx_t &ctx) const override {
218 status_t status = status::success;
219 const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
220 const auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
221 const auto ws = CTX_IN_MEM(const data_t *, DNNL_ARG_WORKSPACE);
222 const auto diff_src
223 = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
224 CHECK(status);
225
226 const auto ker = ker_.get();
227 const auto ker_first = ker_first_.get();
228 const auto ker_last = ker_last_.get();
229
230 parallel(0, [&](const int ithr, const int nthr) {
231 size_t start {0}, end {0};
232 const int C16 = C_ / vsize_;
233 const size_t work_amount
234 = use_h_parallelism_ ? N_ * C16 * H_ : N_ * C16;
235
236 balance211(work_amount, nthr, ithr, start, end);
237 if (use_h_parallelism_) {
238 int n {0}, c16 {0}, h {0};
239 nd_iterator_init(start, n, N_, h, H_, c16, C16);
240 for (size_t iwork = start; iwork < end; ++iwork) {
241 const auto offset = n * C_ * H_ * W_
242 + c16 * H_ * W_ * vsize_ + h * W_ * vsize_;
243 const auto ws_offset0 = n * C_ * H_ * 2 * W_
244 + c16 * H_ * 2 * W_ * vsize_ + h * 2 * W_ * vsize_;
245 const auto ws_offset1 = ws_offset0 + W_ * vsize_;
246
247 typename lrn::jit_avx512_common_lrn_kernel_bwd_blocked_t<
248 d_type>::jit_args_bwd_t args;
249 args.src = &src[offset];
250 args.diff_dst = &diff_dst[offset];
251 args.ws0 = ws ? &ws[ws_offset0] : nullptr;
252 args.ws1 = ws ? &ws[ws_offset1] : nullptr;
253 args.diff_src = &diff_src[offset];
254
255 if (C16 == 1)
256 (*ker)(&args);
257 else if (c16 == 0)
258 (*ker_first)(&args);
259 else if (c16 == C16 - 1)
260 (*ker_last)(&args);
261 else
262 (*ker)(&args);
263 nd_iterator_step(n, N_, h, H_, c16, C16);
264 }
265 } else {
266 int n {0}, c16 {0};
267 nd_iterator_init(start, n, N_, c16, C16);
268 for (size_t iwork = start; iwork < end; ++iwork) {
269 const auto offset
270 = n * C_ * H_ * W_ + c16 * H_ * W_ * vsize_;
271 const auto ws_offset0
272 = n * C_ * H_ * 2 * W_ + c16 * H_ * 2 * W_ * vsize_;
273 const auto ws_offset1 = ws_offset0 + H_ * W_ * vsize_;
274
275 typename lrn::jit_avx512_common_lrn_kernel_bwd_blocked_t<
276 d_type>::jit_args_bwd_t args;
277 args.src = &src[offset];
278 args.diff_dst = &diff_dst[offset];
279 args.ws0 = ws ? &ws[ws_offset0] : nullptr;
280 args.ws1 = ws ? &ws[ws_offset1] : nullptr;
281 args.diff_src = &diff_src[offset];
282
283 if (C16 == 1)
284 (*ker)(&args);
285 else if (c16 == 0)
286 (*ker_first)(&args);
287 else if (c16 == C16 - 1)
288 (*ker_last)(&args);
289 else
290 (*ker)(&args);
291
292 nd_iterator_step(n, N_, c16, C16);
293 }
294 }
295 });
296
297 return status::success;
298 }
299
300private:
301 std::unique_ptr<lrn::jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>>
302 ker_, ker_first_, ker_last_;
303 static constexpr int vsize_ = 16;
304 const int N_;
305 const int C_;
306 const int H_;
307 const int W_;
308 const int use_h_parallelism_;
309};
310
311} // namespace lrn
312} // namespace x64
313} // namespace cpu
314} // namespace impl
315} // namespace dnnl
316
317#endif
318