1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_LRN_JIT_AVX512_COMMON_LRN_BWD_BLOCKED_HPP
18#define CPU_X64_LRN_JIT_AVX512_COMMON_LRN_BWD_BLOCKED_HPP
19
20#include "cpu/x64/lrn/jit_avx512_common_lrn_bwd_base.hpp"
21#include "cpu/x64/lrn/jit_avx512_common_lrn_utils.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
26namespace x64 {
27namespace lrn {
28
29using namespace dnnl::impl::status;
30using namespace dnnl::impl::utils;
31using namespace data_type;
32using namespace Xbyak;
33using namespace Xbyak::util;
34
35template <data_type_t d_type>
36class jit_avx512_common_lrn_kernel_bwd_blocked_t
37 : public jit_avx512_common_lrn_kernel_bwd_t<d_type> {
38public:
39 using data_t = typename prec_traits<d_type>::type;
40
41 struct jit_args_bwd_t {
42 const data_t *src, *diff_dst, *ws0, *ws1;
43 data_t *diff_src;
44 };
45
46 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_bwd_blocked_t)
47
48 jit_avx512_common_lrn_kernel_bwd_blocked_t(const struct nChw16c_across_t &J,
49 float alpha, float beta, int local_size, int use_h_parallel,
50 void *code_ptr = nullptr,
51 size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE);
52
53private:
54 void generate() override;
55 void compute_loop(int loop_size_param);
56
57 int xmm_size_, zmm_size_, buffer_block_, buffer_nest_offset_,
58 src_prev_offset_;
59 int HW_, W_;
60 across_version version_;
61
62 const Reg64 hw_ = r10;
63
64 const int xws1_prev_ = 3;
65 const int xdiffdst_prev_ = 4;
66 const int zws1_ = 3;
67
68 const int xws1_next_ = 3;
69 const int xdiffdst_next_ = 5;
70
71 int use_h_parallelism_;
72};
73
74} // namespace lrn
75} // namespace x64
76} // namespace cpu
77} // namespace impl
78} // namespace dnnl
79
80#endif
81