1 | /******************************************************************************* |
2 | * Copyright 2020-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_LRN_JIT_LRN_AVX512_NHWC_EXECUTOR_HPP |
18 | #define CPU_X64_LRN_JIT_LRN_AVX512_NHWC_EXECUTOR_HPP |
19 | |
20 | #include "cpu/x64/lrn/jit_avx512_common_lrn_bwd_nhwc.hpp" |
21 | #include "cpu/x64/lrn/jit_avx512_common_lrn_fwd_nhwc.hpp" |
22 | #include "cpu/x64/lrn/lrn_executor.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace cpu { |
27 | namespace x64 { |
28 | namespace lrn { |
29 | |
30 | template <::dnnl::impl::data_type_t d_type, typename PD_T> |
31 | class lrn_avx512_nhwc_executor_fwd_t : public i_lrn_executor_t { |
32 | public: |
33 | lrn_avx512_nhwc_executor_fwd_t(const PD_T *pd) |
34 | : ker_(utils::make_unique< |
35 | lrn::jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>>(pd->C(), |
36 | pd->desc()->prop_kind, |
37 | pd->desc()->lrn_alpha / pd->desc()->local_size, |
38 | pd->desc()->lrn_beta, pd->desc()->lrn_k, |
39 | pd->desc()->local_size)) |
40 | , N_(pd->MB()) |
41 | , C_(pd->C()) |
42 | , H_(pd->H()) |
43 | , W_(pd->W()) {} |
44 | |
45 | using data_t = typename prec_traits<d_type>::type; |
46 | |
47 | status_t create_kernel() override { return ker_->create_kernel(); } |
48 | |
49 | status_t execute(const exec_ctx_t &ctx) const override { |
50 | status_t status = status::success; |
51 | const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC); |
52 | const auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status); |
53 | CHECK(status); |
54 | const auto ws = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_WORKSPACE, status); |
55 | CHECK(status); |
56 | |
57 | const auto ker = ker_.get(); |
58 | parallel_nd(N_, H_ * W_, [&](dim_t n, dim_t pixel_id) { |
59 | typename lrn::jit_avx512_common_lrn_kernel_fwd_t< |
60 | d_type>::jit_args_fwd_t args; |
61 | const auto offset = n * C_ * H_ * W_ + pixel_id * C_; |
62 | const auto ws_offset0 = offset * 2; |
63 | const auto ws_offset1 = ws_offset0 + C_; |
64 | |
65 | args.src = &src[offset]; |
66 | args.dst = &dst[offset]; |
67 | args.ws0 = ws ? &ws[ws_offset0] : nullptr; |
68 | args.ws1 = ws ? &ws[ws_offset1] : nullptr; |
69 | |
70 | (*ker)(&args); |
71 | }); |
72 | |
73 | return status::success; |
74 | } |
75 | |
76 | virtual ~lrn_avx512_nhwc_executor_fwd_t() = default; |
77 | |
78 | private: |
79 | std::unique_ptr<jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>> ker_; |
80 | const int N_; |
81 | const int C_; |
82 | const int H_; |
83 | const int W_; |
84 | }; |
85 | template <::dnnl::impl::data_type_t d_type, typename PD_T> |
86 | class lrn_avx512_nhwc_executor_bwd_t : public i_lrn_executor_t { |
87 | public: |
88 | lrn_avx512_nhwc_executor_bwd_t(const PD_T *pd) |
89 | : ker_ {utils::make_unique< |
90 | lrn::jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>>(pd->C(), |
91 | pd->desc()->lrn_alpha / pd->desc()->local_size, |
92 | pd->desc()->lrn_beta, pd->desc()->local_size)} |
93 | , N_(pd->MB()) |
94 | , C_(pd->C()) |
95 | , H_(pd->H()) |
96 | , W_(pd->W()) {} |
97 | using data_t = typename prec_traits<d_type>::type; |
98 | |
99 | status_t create_kernel() override { return ker_->create_kernel(); } |
100 | |
101 | status_t execute(const exec_ctx_t &ctx) const override { |
102 | status_t status = status::success; |
103 | auto src = CTX_IN_MEM(data_t *, DNNL_ARG_SRC); |
104 | auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status); |
105 | CHECK(status); |
106 | auto diff_dst = CTX_IN_MEM(data_t *, DNNL_ARG_DIFF_DST); |
107 | auto ws = CTX_IN_MEM(data_t *, DNNL_ARG_WORKSPACE); |
108 | |
109 | const auto ker = ker_.get(); |
110 | parallel_nd(N_, H_ * W_, [&](dim_t n, dim_t pixel_id) { |
111 | typename lrn::jit_avx512_common_lrn_kernel_bwd_nhwc_t< |
112 | d_type>::jit_args_bwd_t args; |
113 | const auto offset = n * C_ * H_ * W_ + pixel_id * C_; |
114 | const auto ws_offset0 = offset * 2; |
115 | const auto ws_offset1 = ws_offset0 + C_; |
116 | |
117 | args.src = &src[offset]; |
118 | args.diff_dst = &diff_dst[offset]; |
119 | args.ws0 = &ws[ws_offset0]; |
120 | args.ws1 = &ws[ws_offset1]; |
121 | args.diff_src = &diff_src[offset]; |
122 | |
123 | (*ker)(&args); |
124 | }); |
125 | |
126 | return status::success; |
127 | } |
128 | |
129 | virtual ~lrn_avx512_nhwc_executor_bwd_t() = default; |
130 | |
131 | private: |
132 | std::unique_ptr<jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>> ker_; |
133 | const int N_; |
134 | const int C_; |
135 | const int H_; |
136 | const int W_; |
137 | }; |
138 | |
139 | } // namespace lrn |
140 | } // namespace x64 |
141 | } // namespace cpu |
142 | } // namespace impl |
143 | } // namespace dnnl |
144 | |
145 | #endif |
146 | |