1/*******************************************************************************
2* Copyright 2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_LRN_JIT_AVX512_COMMON_LRN_BWD_NHWC_HPP
18#define CPU_X64_LRN_JIT_AVX512_COMMON_LRN_BWD_NHWC_HPP
19
20#include "cpu/x64/lrn/jit_avx512_common_lrn_bwd_base.hpp"
21#include "cpu/x64/lrn/jit_avx512_common_lrn_utils.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
26namespace x64 {
27namespace lrn {
28
29using namespace dnnl::impl::status;
30using namespace dnnl::impl::utils;
31using namespace data_type;
32using namespace Xbyak;
33using namespace Xbyak::util;
34
// Kernel class templated on the data type `d_type` and publicly derived from
// jit_avx512_common_lrn_kernel_bwd_t<d_type>.
// NOTE(review): judging by the class name and the include guard this is the
// AVX-512 JIT kernel for the backward LRN pass over NHWC data -- confirm
// against the out-of-line definitions, which are not part of this header.
//
// Only declarations live here; every member function below is defined
// elsewhere.
35template <data_type_t d_type>
36class jit_avx512_common_lrn_kernel_bwd_nhwc_t
37 : public jit_avx512_common_lrn_kernel_bwd_t<d_type> {
38public:
    // Constructor.
    //   C          -- unsigned value; presumably the channel count that ends
    //                 up in C_ (TODO confirm in the definition).
    //   alpha/beta -- float parameters; presumably the LRN coefficients.
    //   local_size -- int; presumably the LRN window size from which half_ls_
    //                 is derived (TODO confirm).
    //   code_ptr   -- optional buffer pointer, defaults to nullptr.
    //   code_size  -- defaults to 1 * Xbyak::DEFAULT_MAX_CODE_SIZE.
    // NOTE(review): code_ptr/code_size look like they are forwarded to the
    // code generator base class -- verify.
39 jit_avx512_common_lrn_kernel_bwd_nhwc_t(unsigned C, float alpha, float beta,
40 int local_size, void *code_ptr = nullptr,
41 size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE);
42
    // Macro defined outside this file; it receives the class name as its
    // argument. Its expansion is not visible here.
43 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_bwd_nhwc_t)
44
45private:
    // Overrides the virtual generate() inherited through the base class.
46 void generate() override;

    // Internal helpers. The comments below describe only what the signatures
    // show; the exact behavior lives in the definitions.
47 void set_up_ker_params();
    // Takes the number of full 16-channel blocks (per the parameter name) and
    // the size of the remaining channel tail.
48 void execute_compute_loop(unsigned num_full_16c_blocks, unsigned C_tail);
    // C_tail defaults to 0 and loop_size_param defaults to 1.
49 void compute_loop(across_version version, tail_mode tail_proc,
50 unsigned C_tail = 0, int loop_size_param = 1);
51 void compute(int loop_size_param, tail_mode tail_proc);
52 void increment_loop_params(std::size_t offset);
53 void load_compute_data(
54 across_version version, tail_mode tail_proc, int loop_size_param);
55 void store_compute_data(
56 int loop_size_param, tail_mode tail_m, unsigned C_tail);
    // NOTE(review): reserve/unreserve are presumably meant to be called in
    // matching pairs with the same `space` -- confirm in the definitions.
57 void reserve_stack_space(std::size_t space);
58 void unreserve_stack_space(std::size_t space);
59 void load_data_to_stack(
60 unsigned C_tail, across_version version, tail_mode tail_proc);
    // Returns an int offset for the given 64-bit register and tail mode.
61 int get_stack_offset(const Reg64 reg, tail_mode tail_proc);
62
    // Constant integer tables. Being const, they can only receive their
    // contents during construction; the values are not visible in this header.
63 const std::vector<int> tmp_mask_prev_;
64 const std::vector<int> tmp_mask_next_;
65
    // 64 -- presumably the size in bytes of one 512-bit ZMM register.
66 static constexpr int zmm_size_ = 64;
    // NOTE(review): the three constants below look like indices of scratch
    // vector registers used when moving data to/from the stack -- confirm.
67 static constexpr int tmp_load_to_stack_idx_prev_ = 12;
68 static constexpr int tmp_load_to_stack_idx_tail_ = 13;
69 static constexpr int tmp_store_from_stack_idx_tail_ = 14;
70
    // Named aliases for the 64-bit general-purpose registers r11 and r12.
71 const Reg64 mask_ = r11;
72 const Reg64 blockC_ = r12;
73
    // Presumably half of the constructor's local_size (TODO confirm).
74 const int half_ls_;
    // Presumably a copy of the constructor's C argument (TODO confirm).
75 unsigned C_;
76};
77
78} // namespace lrn
79} // namespace x64
80} // namespace cpu
81} // namespace impl
82} // namespace dnnl
83
84#endif
85