1 | /******************************************************************************* |
2 | * Copyright 2020-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_LRN_JIT_LRN_AVX512_BLOCKED_EXECUTOR_HPP |
18 | #define CPU_X64_LRN_JIT_LRN_AVX512_BLOCKED_EXECUTOR_HPP |
19 | |
20 | #include "cpu/x64/lrn/jit_avx512_common_lrn_bwd_blocked.hpp" |
21 | #include "cpu/x64/lrn/jit_avx512_common_lrn_fwd_blocked.hpp" |
22 | #include "cpu/x64/lrn/lrn_executor.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace cpu { |
27 | namespace x64 { |
28 | namespace lrn { |
29 | |
30 | template <::dnnl::impl::data_type_t d_type, typename PD_T> |
31 | class lrn_avx512_blocked_executor_fwd_t : public i_lrn_executor_t { |
32 | public: |
33 | lrn_avx512_blocked_executor_fwd_t(const PD_T *pd) |
34 | : ker_(nullptr) |
35 | , ker_first_(nullptr) |
36 | , ker_last_(nullptr) |
37 | , N_(pd->MB()) |
38 | , C_(pd->C()) |
39 | , H_(pd->H()) |
40 | , W_(pd->W()) |
41 | , use_h_parallelism_(H_ > 28 ? 1 : 0) { |
42 | |
43 | const int local_size = pd->desc()->local_size; |
44 | const float alpha = pd->desc()->lrn_alpha / local_size; |
45 | const float beta = pd->desc()->lrn_beta; |
46 | const auto pk = pd->desc()->prop_kind; |
47 | const float k = pd->desc()->lrn_k; |
48 | |
49 | if (C_ / vsize_ == 1) { |
50 | ker_ = utils::make_unique< |
51 | lrn::jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>>( |
52 | lrn::nChw16c_across_t(H_, W_, lrn::across_version::Single), |
53 | pk, use_h_parallelism_, alpha, beta, k, local_size); |
54 | } else { |
55 | ker_ = utils::make_unique< |
56 | lrn::jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>>( |
57 | lrn::nChw16c_across_t(H_, W_, lrn::across_version::Middle), |
58 | pk, use_h_parallelism_, alpha, beta, k, local_size); |
59 | ker_first_ = utils::make_unique< |
60 | lrn::jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>>( |
61 | lrn::nChw16c_across_t(H_, W_, lrn::across_version::First), |
62 | pk, use_h_parallelism_, alpha, beta, k, local_size); |
63 | ker_last_ = utils::make_unique< |
64 | lrn::jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>>( |
65 | lrn::nChw16c_across_t(H_, W_, lrn::across_version::Last), |
66 | pk, use_h_parallelism_, alpha, beta, k, local_size); |
67 | } |
68 | } |
69 | |
70 | using data_t = typename prec_traits<d_type>::type; |
71 | |
72 | status_t create_kernel() override { |
73 | CHECK(ker_->create_kernel()); |
74 | if (ker_first_) CHECK(ker_first_->create_kernel()); |
75 | if (ker_last_) CHECK(ker_last_->create_kernel()); |
76 | return status::success; |
77 | } |
78 | |
79 | status_t execute(const exec_ctx_t &ctx) const override { |
80 | status_t status = status::success; |
81 | const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
82 | const auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status); |
83 | CHECK(status); |
84 | const auto ws = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_WORKSPACE, status); |
85 | CHECK(status); |
86 | |
87 | const auto ker = ker_.get(); |
88 | const auto ker_first = ker_first_.get(); |
89 | const auto ker_last = ker_last_.get(); |
90 | |
91 | parallel(0, [&](const int ithr, const int nthr) { |
92 | size_t start {0}, end {0}; |
93 | const int C16 = C_ / vsize_; |
94 | const size_t work_amount |
95 | = use_h_parallelism_ ? N_ * C16 * H_ : N_ * C16; |
96 | |
97 | balance211(work_amount, nthr, ithr, start, end); |
98 | if (use_h_parallelism_) { |
99 | int n {0}, c16 {0}, h {0}; |
100 | nd_iterator_init(start, n, N_, c16, C16, h, H_); |
101 | for (size_t iwork = start; iwork < end; ++iwork) { |
102 | const auto offset = n * C_ * H_ * W_ |
103 | + c16 * H_ * W_ * vsize_ + h * W_ * vsize_; |
104 | const auto ws_offset0 = n * C_ * H_ * 2 * W_ |
105 | + c16 * H_ * 2 * W_ * vsize_ + h * 2 * W_ * vsize_; |
106 | const auto ws_offset1 = ws_offset0 + W_ * vsize_; |
107 | |
108 | typename lrn::jit_avx512_common_lrn_kernel_fwd_t< |
109 | d_type>::jit_args_fwd_t args; |
110 | args.src = &src[offset]; |
111 | args.dst = &dst[offset]; |
112 | args.ws0 = ws ? &ws[ws_offset0] : nullptr; |
113 | args.ws1 = ws ? &ws[ws_offset1] : nullptr; |
114 | |
115 | if (C16 == 1) |
116 | (*ker)(&args); |
117 | else if (c16 == 0) |
118 | (*ker_first)(&args); |
119 | else if (c16 == C16 - 1) |
120 | (*ker_last)(&args); |
121 | else |
122 | (*ker)(&args); |
123 | nd_iterator_step(n, N_, c16, C16, h, H_); |
124 | } |
125 | } else { |
126 | int n {0}, c16 {0}; |
127 | nd_iterator_init(start, n, N_, c16, C16); |
128 | for (size_t iwork = start; iwork < end; ++iwork) { |
129 | const auto offset |
130 | = n * C_ * H_ * W_ + c16 * H_ * W_ * vsize_; |
131 | const auto ws_offset0 |
132 | = n * C_ * H_ * 2 * W_ + c16 * H_ * 2 * W_ * vsize_; |
133 | const auto ws_offset1 = ws_offset0 + H_ * W_ * vsize_; |
134 | |
135 | typename lrn::jit_avx512_common_lrn_kernel_fwd_t< |
136 | d_type>::jit_args_fwd_t args; |
137 | args.src = &src[offset]; |
138 | args.dst = &dst[offset]; |
139 | args.ws0 = ws ? &ws[ws_offset0] : nullptr; |
140 | args.ws1 = ws ? &ws[ws_offset1] : nullptr; |
141 | |
142 | if (C16 == 1) |
143 | (*ker)(&args); |
144 | else if (c16 == 0) |
145 | (*ker_first)(&args); |
146 | else if (c16 == C16 - 1) |
147 | (*ker_last)(&args); |
148 | else |
149 | (*ker)(&args); |
150 | |
151 | nd_iterator_step(n, N_, c16, C16); |
152 | } |
153 | } |
154 | }); |
155 | |
156 | return status::success; |
157 | } |
158 | |
159 | private: |
160 | std::unique_ptr<lrn::jit_avx512_common_lrn_kernel_fwd_blocked_t<d_type>> |
161 | ker_, ker_first_, ker_last_; |
162 | static constexpr int vsize_ = 16; |
163 | const int N_; |
164 | const int C_; |
165 | const int H_; |
166 | const int W_; |
167 | const int use_h_parallelism_; |
168 | }; |
169 | |
170 | template <::dnnl::impl::data_type_t d_type, typename PD_T> |
171 | class lrn_avx512_blocked_executor_bwd_t : public i_lrn_executor_t { |
172 | public: |
173 | lrn_avx512_blocked_executor_bwd_t(const PD_T *pd) |
174 | : ker_(nullptr) |
175 | , ker_first_(nullptr) |
176 | , ker_last_(nullptr) |
177 | , N_(pd->MB()) |
178 | , C_(pd->C()) |
179 | , H_(pd->H()) |
180 | , W_(pd->W()) |
181 | , use_h_parallelism_(H_ > 28 ? 1 : 0) { |
182 | |
183 | const int local_size = pd->desc()->local_size; |
184 | const float alpha = pd->desc()->lrn_alpha / local_size; |
185 | const float beta = pd->desc()->lrn_beta; |
186 | |
187 | if (C_ / vsize_ == 1) { |
188 | ker_ = utils::make_unique< |
189 | lrn::jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>>( |
190 | lrn::nChw16c_across_t(H_, W_, lrn::across_version::Single), |
191 | alpha, beta, local_size, use_h_parallelism_); |
192 | } else { |
193 | ker_ = utils::make_unique< |
194 | lrn::jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>>( |
195 | lrn::nChw16c_across_t(H_, W_, lrn::across_version::Middle), |
196 | alpha, beta, local_size, use_h_parallelism_); |
197 | ker_first_ = utils::make_unique< |
198 | lrn::jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>>( |
199 | lrn::nChw16c_across_t(H_, W_, lrn::across_version::First), |
200 | alpha, beta, local_size, use_h_parallelism_); |
201 | ker_last_ = utils::make_unique< |
202 | lrn::jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>>( |
203 | lrn::nChw16c_across_t(H_, W_, lrn::across_version::Last), |
204 | alpha, beta, local_size, use_h_parallelism_); |
205 | } |
206 | } |
207 | |
208 | using data_t = typename prec_traits<d_type>::type; |
209 | |
210 | status_t create_kernel() override { |
211 | CHECK(ker_->create_kernel()); |
212 | if (ker_first_) CHECK(ker_first_->create_kernel()); |
213 | if (ker_last_) CHECK(ker_last_->create_kernel()); |
214 | return status::success; |
215 | } |
216 | |
217 | status_t execute(const exec_ctx_t &ctx) const override { |
218 | status_t status = status::success; |
219 | const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
220 | const auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST); |
221 | const auto ws = CTX_IN_MEM(const data_t *, DNNL_ARG_WORKSPACE); |
222 | const auto diff_src |
223 | = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status); |
224 | CHECK(status); |
225 | |
226 | const auto ker = ker_.get(); |
227 | const auto ker_first = ker_first_.get(); |
228 | const auto ker_last = ker_last_.get(); |
229 | |
230 | parallel(0, [&](const int ithr, const int nthr) { |
231 | size_t start {0}, end {0}; |
232 | const int C16 = C_ / vsize_; |
233 | const size_t work_amount |
234 | = use_h_parallelism_ ? N_ * C16 * H_ : N_ * C16; |
235 | |
236 | balance211(work_amount, nthr, ithr, start, end); |
237 | if (use_h_parallelism_) { |
238 | int n {0}, c16 {0}, h {0}; |
239 | nd_iterator_init(start, n, N_, h, H_, c16, C16); |
240 | for (size_t iwork = start; iwork < end; ++iwork) { |
241 | const auto offset = n * C_ * H_ * W_ |
242 | + c16 * H_ * W_ * vsize_ + h * W_ * vsize_; |
243 | const auto ws_offset0 = n * C_ * H_ * 2 * W_ |
244 | + c16 * H_ * 2 * W_ * vsize_ + h * 2 * W_ * vsize_; |
245 | const auto ws_offset1 = ws_offset0 + W_ * vsize_; |
246 | |
247 | typename lrn::jit_avx512_common_lrn_kernel_bwd_blocked_t< |
248 | d_type>::jit_args_bwd_t args; |
249 | args.src = &src[offset]; |
250 | args.diff_dst = &diff_dst[offset]; |
251 | args.ws0 = ws ? &ws[ws_offset0] : nullptr; |
252 | args.ws1 = ws ? &ws[ws_offset1] : nullptr; |
253 | args.diff_src = &diff_src[offset]; |
254 | |
255 | if (C16 == 1) |
256 | (*ker)(&args); |
257 | else if (c16 == 0) |
258 | (*ker_first)(&args); |
259 | else if (c16 == C16 - 1) |
260 | (*ker_last)(&args); |
261 | else |
262 | (*ker)(&args); |
263 | nd_iterator_step(n, N_, h, H_, c16, C16); |
264 | } |
265 | } else { |
266 | int n {0}, c16 {0}; |
267 | nd_iterator_init(start, n, N_, c16, C16); |
268 | for (size_t iwork = start; iwork < end; ++iwork) { |
269 | const auto offset |
270 | = n * C_ * H_ * W_ + c16 * H_ * W_ * vsize_; |
271 | const auto ws_offset0 |
272 | = n * C_ * H_ * 2 * W_ + c16 * H_ * 2 * W_ * vsize_; |
273 | const auto ws_offset1 = ws_offset0 + H_ * W_ * vsize_; |
274 | |
275 | typename lrn::jit_avx512_common_lrn_kernel_bwd_blocked_t< |
276 | d_type>::jit_args_bwd_t args; |
277 | args.src = &src[offset]; |
278 | args.diff_dst = &diff_dst[offset]; |
279 | args.ws0 = ws ? &ws[ws_offset0] : nullptr; |
280 | args.ws1 = ws ? &ws[ws_offset1] : nullptr; |
281 | args.diff_src = &diff_src[offset]; |
282 | |
283 | if (C16 == 1) |
284 | (*ker)(&args); |
285 | else if (c16 == 0) |
286 | (*ker_first)(&args); |
287 | else if (c16 == C16 - 1) |
288 | (*ker_last)(&args); |
289 | else |
290 | (*ker)(&args); |
291 | |
292 | nd_iterator_step(n, N_, c16, C16); |
293 | } |
294 | } |
295 | }); |
296 | |
297 | return status::success; |
298 | } |
299 | |
300 | private: |
301 | std::unique_ptr<lrn::jit_avx512_common_lrn_kernel_bwd_blocked_t<d_type>> |
302 | ker_, ker_first_, ker_last_; |
303 | static constexpr int vsize_ = 16; |
304 | const int N_; |
305 | const int C_; |
306 | const int H_; |
307 | const int W_; |
308 | const int use_h_parallelism_; |
309 | }; |
310 | |
311 | } // namespace lrn |
312 | } // namespace x64 |
313 | } // namespace cpu |
314 | } // namespace impl |
315 | } // namespace dnnl |
316 | |
317 | #endif |
318 | |