1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_LRN_JIT_AVX512_COMMON_LRN_BWD_BASE_HPP
18#define CPU_X64_LRN_JIT_AVX512_COMMON_LRN_BWD_BASE_HPP
19
20#include <functional>
21#include <memory>
22#include "common/c_types_map.hpp"
23#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
24#include "cpu/x64/jit_generator.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace cpu {
29namespace x64 {
30namespace lrn {
31
// Type aliases shared by the declarations in this namespace.
// acc_data_t is plain float; acc_data_bf16_t is a raw 16-bit unsigned integer
// (presumably the storage form of a bf16 value — confirm at the use sites).
32using acc_data_t = float;
33using acc_data_bf16_t = uint16_t;
34
// Using-directives placed at namespace scope in a header: every name from
// these namespaces becomes visible unqualified inside the enclosing lrn
// namespace for any translation unit that includes this file. The
// declarations below rely on this (e.g. Zmm, Reg64, rax, bf16, f16 are used
// without qualification).
35using namespace dnnl::impl::status;
36using namespace dnnl::impl::utils;
37using namespace data_type;
38using namespace Xbyak;
39using namespace Xbyak::util;
40
// JIT code generator (derives from jit_generator) parameterized on the data
// type d_type. Judging by its name and the include guard, this is the shared
// base of the AVX-512 backward LRN kernels; only declarations are visible in
// this header, the definitions live elsewhere.
//
// NOTE(review): non-static data members are initialized in declaration order,
// so the member order below has to stay in sync with the out-of-line
// constructor's initializer list.
// NOTE(review): std::vector is used below but this header does not include
// <vector> directly — presumably it arrives through another include; confirm.
41template <data_type_t d_type>
42class jit_avx512_common_lrn_kernel_bwd_t : public jit_generator {
43public:
// Constructor (defined elsewhere). Receives the alpha, beta and local_size
// parameters together with code_ptr / code_size / name; name defaults to
// jit_name().
// NOTE(review): presumably code_ptr, code_size and name are forwarded to
// jit_generator, and alpha/beta are folded into nalphabeta_ — confirm against
// the definition.
44 jit_avx512_common_lrn_kernel_bwd_t(float alpha, float beta, int local_size,
45 void *code_ptr, size_t code_size, const char *name = jit_name());
46
// Element type selected from d_type through prec_traits.
47 using data_t = typename prec_traits<d_type>::type;
48
// Argument bundle for the kernel.
//   src, diff_dst, ws0, ws1 : pointers to const data_t (read-only inputs).
//   diff_src                : pointer to mutable data_t (the only writable
//                             buffer in the struct).
//   mask                    : static table of 48 int32 values, defined
//                             outside this header.
//   mask_ptr                : pointer to const int32 — presumably set to
//                             point at mask by the constructor; confirm.
49 struct jit_args_bwd_t {
50 jit_args_bwd_t();
51 const data_t *src, *diff_dst, *ws0, *ws1;
52 data_t *diff_src;
53 static const int32_t mask[48];
54 const int32_t *mask_ptr;
55 };
56
// Project macro; presumably supplies jit_name() (used as the constructor's
// default argument above) and related helpers — confirm in jit_generator.hpp.
57 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_bwd_t);
58
59protected:
// Map a pair (irb, i) to a vector register of the requested width
// (Zmm / Ymm / Xmm). Presumably irb is a register-block index and i one of
// the logical indices declared below (zsrc_, zdiffdst_, ...) — confirm
// against the definitions.
60 Zmm zreg(int irb, int i) const;
61 Ymm yreg(int irb, int i) const;
62 Xmm xreg(int irb, int i) const;
63
// Load/store helpers (defined elsewhere). Note that the Address arguments are
// taken by const value, not by reference.
// NOTE(review): the exact meaning of non_temp_hint, from_stack and the
// tail/stack offset parameters is not visible in this header.
64 void store_data(bool non_temp_hint, const Address addr, Zmm zr);
65 void load_data(Xmm reg, const Address p, bool from_stack = false);
66 void load_tail(int tail_value, Reg64 src, int src_mem_offset,
67 int dst_stack_offset, int tmp_load_to_stack_idx_tail);
68 void store_tail(int tail_value, Zmm src, Reg64 dst, int dst_mem_offset,
69 int tmp_stack_offset, int tmp_idx);
70
// Fixed general-purpose register assignments.
// imm_addr16_ (bx) is the low 16 bits of imm_addr64_ (rbx); znalphabeta_
// (zmm0) and xnalphabeta_ (xmm0) are the same vector register at different
// widths.
71 const Reg64 src_ = rax;
72 const Reg64 diffsrc_ = r8;
73 const Reg64 diffdst_ = r9;
74 const Reg64 workspace0_ = rdx;
75 const Reg64 workspace1_ = rsi;
76 const Reg64 imm_addr64_ = rbx;
77 const Reg64 param_ = abi_param1;
78 const Reg16 imm_addr16_ = bx;
79 const Zmm znalphabeta_ = zmm0;
80 const Xmm xnalphabeta_ = xmm0;
81
// Registers set aside for bf16 emulation: zmm28..zmm31 plus one scratch GPR.
// NOTE(review): bf16_emu_scratch_ is rax, the same register as src_ above.
// Whether the two uses can overlap depends on code outside this header —
// confirm.
82 const Zmm bf16_emu_reserv_1_ = Zmm(28);
83 const Zmm bf16_emu_reserv_2_ = Zmm(29);
84 const Reg64 bf16_emu_scratch_ = rax;
85 const Zmm bf16_emu_reserv_3_ = Zmm(30);
86 const Zmm bf16_emu_reserv_4_ = Zmm(31);
87 const int local_size_;
88
// Small integer indices (compile-time constants). Presumably passed as the
// `i` argument of zreg/yreg/xreg to pick a register within a block — confirm.
89 static constexpr int z_tmp_ = 7;
90
91 static constexpr int zdiffdst_ = 1;
92 static constexpr int zdiffsrc_ = 2;
93 static constexpr int zsrc_ = 3;
94
// Lists of integer indices, filled in by the constructor (defined elsewhere).
// Presumably the register indices holding the previous / next neighbours of
// the current element — confirm.
95 const std::vector<int> z_prev_;
96 const std::vector<int> z_next_;
97
98 static constexpr int zws0_ = 4;
99
// nalphabeta_ is the only non-const scalar member here; the remaining values
// are fixed at construction.
100 float nalphabeta_;
101 const bool emulateBfloat_;
102 const int regs_used_per_block_;
103 const int reg_block_;
// 32 when d_type is bf16 or f16, 64 for every other data type (presumably a
// vector length in bytes — confirm).
104 static constexpr int vlen_ = utils::one_of(d_type, bf16, f16) ? 32 : 64;
// Owning pointer to the bf16 emulation helper; presumably only populated when
// emulateBfloat_ is true — confirm against the constructor.
105 std::unique_ptr<bf16_emulation_t> bf16_emu_;
106};
107
108} // namespace lrn
109} // namespace x64
110} // namespace cpu
111} // namespace impl
112} // namespace dnnl
113
114#endif
115