1/*******************************************************************************
2* Copyright 2020-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_LRN_JIT_LRN_AVX512_NHWC_EXECUTOR_HPP
18#define CPU_X64_LRN_JIT_LRN_AVX512_NHWC_EXECUTOR_HPP
19
20#include "cpu/x64/lrn/jit_avx512_common_lrn_bwd_nhwc.hpp"
21#include "cpu/x64/lrn/jit_avx512_common_lrn_fwd_nhwc.hpp"
22#include "cpu/x64/lrn/lrn_executor.hpp"
23
24namespace dnnl {
25namespace impl {
26namespace cpu {
27namespace x64 {
28namespace lrn {
29
30template <::dnnl::impl::data_type_t d_type, typename PD_T>
31class lrn_avx512_nhwc_executor_fwd_t : public i_lrn_executor_t {
32public:
33 lrn_avx512_nhwc_executor_fwd_t(const PD_T *pd)
34 : ker_(utils::make_unique<
35 lrn::jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>>(pd->C(),
36 pd->desc()->prop_kind,
37 pd->desc()->lrn_alpha / pd->desc()->local_size,
38 pd->desc()->lrn_beta, pd->desc()->lrn_k,
39 pd->desc()->local_size))
40 , N_(pd->MB())
41 , C_(pd->C())
42 , H_(pd->H())
43 , W_(pd->W()) {}
44
45 using data_t = typename prec_traits<d_type>::type;
46
47 status_t create_kernel() override { return ker_->create_kernel(); }
48
49 status_t execute(const exec_ctx_t &ctx) const override {
50 status_t status = status::success;
51 const auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
52 const auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
53 CHECK(status);
54 const auto ws = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_WORKSPACE, status);
55 CHECK(status);
56
57 const auto ker = ker_.get();
58 parallel_nd(N_, H_ * W_, [&](dim_t n, dim_t pixel_id) {
59 typename lrn::jit_avx512_common_lrn_kernel_fwd_t<
60 d_type>::jit_args_fwd_t args;
61 const auto offset = n * C_ * H_ * W_ + pixel_id * C_;
62 const auto ws_offset0 = offset * 2;
63 const auto ws_offset1 = ws_offset0 + C_;
64
65 args.src = &src[offset];
66 args.dst = &dst[offset];
67 args.ws0 = ws ? &ws[ws_offset0] : nullptr;
68 args.ws1 = ws ? &ws[ws_offset1] : nullptr;
69
70 (*ker)(&args);
71 });
72
73 return status::success;
74 }
75
76 virtual ~lrn_avx512_nhwc_executor_fwd_t() = default;
77
78private:
79 std::unique_ptr<jit_avx512_common_lrn_kernel_fwd_nhwc_t<d_type>> ker_;
80 const int N_;
81 const int C_;
82 const int H_;
83 const int W_;
84};
85template <::dnnl::impl::data_type_t d_type, typename PD_T>
86class lrn_avx512_nhwc_executor_bwd_t : public i_lrn_executor_t {
87public:
88 lrn_avx512_nhwc_executor_bwd_t(const PD_T *pd)
89 : ker_ {utils::make_unique<
90 lrn::jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>>(pd->C(),
91 pd->desc()->lrn_alpha / pd->desc()->local_size,
92 pd->desc()->lrn_beta, pd->desc()->local_size)}
93 , N_(pd->MB())
94 , C_(pd->C())
95 , H_(pd->H())
96 , W_(pd->W()) {}
97 using data_t = typename prec_traits<d_type>::type;
98
99 status_t create_kernel() override { return ker_->create_kernel(); }
100
101 status_t execute(const exec_ctx_t &ctx) const override {
102 status_t status = status::success;
103 auto src = CTX_IN_MEM(data_t *, DNNL_ARG_SRC);
104 auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
105 CHECK(status);
106 auto diff_dst = CTX_IN_MEM(data_t *, DNNL_ARG_DIFF_DST);
107 auto ws = CTX_IN_MEM(data_t *, DNNL_ARG_WORKSPACE);
108
109 const auto ker = ker_.get();
110 parallel_nd(N_, H_ * W_, [&](dim_t n, dim_t pixel_id) {
111 typename lrn::jit_avx512_common_lrn_kernel_bwd_nhwc_t<
112 d_type>::jit_args_bwd_t args;
113 const auto offset = n * C_ * H_ * W_ + pixel_id * C_;
114 const auto ws_offset0 = offset * 2;
115 const auto ws_offset1 = ws_offset0 + C_;
116
117 args.src = &src[offset];
118 args.diff_dst = &diff_dst[offset];
119 args.ws0 = &ws[ws_offset0];
120 args.ws1 = &ws[ws_offset1];
121 args.diff_src = &diff_src[offset];
122
123 (*ker)(&args);
124 });
125
126 return status::success;
127 }
128
129 virtual ~lrn_avx512_nhwc_executor_bwd_t() = default;
130
131private:
132 std::unique_ptr<jit_avx512_common_lrn_kernel_bwd_nhwc_t<d_type>> ker_;
133 const int N_;
134 const int C_;
135 const int H_;
136 const int W_;
137};
138
139} // namespace lrn
140} // namespace x64
141} // namespace cpu
142} // namespace impl
143} // namespace dnnl
144
145#endif
146