1/*******************************************************************************
2* Copyright 2020 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_LRN_JIT_AVX512_COMMON_LRN_FWD_NHWC_HPP
18#define CPU_X64_LRN_JIT_AVX512_COMMON_LRN_FWD_NHWC_HPP
19
20#include "cpu/x64/lrn/jit_avx512_common_lrn_fwd_base.hpp"
21#include "cpu/x64/lrn/jit_avx512_common_lrn_utils.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
26namespace x64 {
27namespace lrn {
28
29template <data_type_t d_type>
30class jit_avx512_common_lrn_kernel_fwd_nhwc_t
31 : public jit_avx512_common_lrn_kernel_fwd_t<d_type> {
32public:
33 jit_avx512_common_lrn_kernel_fwd_nhwc_t(unsigned C, prop_kind_t prop_kind,
34 float alpha, float beta, float k, int local_size,
35 void *code_ptr = nullptr,
36 size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE);
37
38 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_fwd_nhwc_t);
39
40private:
41 void generate() override;
42 void set_up_ker_params();
43 void execute_compute_loop(unsigned num_full_16c_blocks, unsigned C_tail);
44 void compute_loop(across_version version, tail_mode tail_mode,
45 unsigned C_tail = 0, int loop_size_param = 1);
46 void compute(int loop_size_param);
47 void increment_loop_params(std::size_t offset);
48 void load_compute_data(
49 across_version version, tail_mode tail_proc, int loop_size_param);
50 void store_compute_data(
51 int loop_size_param, tail_mode tail_mode, unsigned C_tail);
52 void reserve_stack_space(std::size_t space);
53 void unreserve_stack_space(std::size_t space);
54 void load_data_to_stack(
55 unsigned C_tail, across_version version, tail_mode tail_mode);
56
57 const std::vector<int> tmp_mask_prev_;
58 const std::vector<int> tmp_mask_next_;
59 static constexpr int tmp_load_to_stack_idx_prev_ = 12;
60 static constexpr int tmp_load_to_stack_idx_tail_ = 13;
61 static constexpr int tmp_store_from_stack_idx_tail_ = 14;
62
63 static constexpr int zmm_size = 64;
64 const Reg64 mask_ = r10;
65 const Reg64 blockC_ = r9;
66
67 const int half_ls_;
68 unsigned C;
69};
70
71} // namespace lrn
72} // namespace x64
73} // namespace cpu
74} // namespace impl
75} // namespace dnnl
76
77#endif
78