1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_LRN_JIT_AVX512_COMMON_LRN_FWD_BASE_HPP
18#define CPU_X64_LRN_JIT_AVX512_COMMON_LRN_FWD_BASE_HPP
19
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>

#include "common/c_types_map.hpp"
#include "cpu/x64/jit_avx512_core_bf16cvt.hpp"
#include "cpu/x64/jit_generator.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace cpu {
29namespace x64 {
30namespace lrn {
31
// Accumulator type for the LRN computation (f32), and a 16-bit integer type
// presumably used to hold raw bf16 bit patterns — TODO confirm at use sites.
using acc_data_t = float;
using acc_data_bf16_t = uint16_t;

// NOTE(review): these using-directives are at namespace scope in a header, so
// they affect name lookup in namespace `lrn` for every translation unit that
// includes this file (e.g. unqualified `bf16`/`f16`, `Zmm`, `rax` below).
// Other LRN sources presumably rely on them; narrowing them would require
// qualifying names in those files as well.
using namespace dnnl::impl::status;
using namespace dnnl::impl::utils;
using namespace data_type;
using namespace Xbyak;
using namespace Xbyak::util;
40
41template <data_type_t d_type>
42class jit_avx512_common_lrn_kernel_fwd_t : public jit_generator {
43public:
44 jit_avx512_common_lrn_kernel_fwd_t(prop_kind_t prop_kind, float alpha,
45 float beta, float k, int local_size, void *code_ptr,
46 size_t code_size, const char *name = jit_name());
47
48 using data_t = typename prec_traits<d_type>::type;
49
50 struct jit_args_fwd_t {
51 jit_args_fwd_t();
52 const data_t *src;
53 data_t *dst, *ws0, *ws1;
54 static const int32_t mask[48];
55 const int32_t *mask_ptr;
56 };
57
58 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_fwd_t);
59
60protected:
61 Zmm zreg(int irb, int i) const;
62 Ymm yreg(int irb, int i) const;
63 Xmm xreg(int irb, int i) const;
64
65 void store_data(const Address addr, Zmm zr, Ymm yr);
66 void load_tail(int tail_value, Reg64 src, int src_mem_offset,
67 int dst_stack_offset, int tmp_load_to_stack_idx_tail);
68 void load_data(Xmm reg, const Address p, bool from_stack = false);
69 void store_tail(int tail_value, Zmm src, Reg64 dst, int dst_mem_offset,
70 int tmp_stack_offset, int tmp_idx);
71
72 prop_kind_t pk_;
73 float alpha_, beta_, k_;
74 static constexpr int xmm_size_ = 4 * sizeof(acc_data_t);
75 static constexpr int zmm_size_ = 64;
76 const Reg64 imm_addr64_ = rbx;
77 const Reg16 imm_addr16_ = bx;
78
79 const Xmm xalpha_ = xmm0;
80 const Zmm zalpha_ = zmm0;
81 const Zmm zk_ = zmm1;
82 const Xmm xk_ = xmm1;
83 const Reg64 src_ = rax;
84 const Reg64 dst_ = r8;
85 const Reg64 ws0_ = rdx;
86 const Reg64 ws1_ = rsi;
87 const Reg64 param_ = abi_param1;
88
89 const int local_size_;
90
91 static constexpr int zc_ = 2;
92 const std::vector<int> z_prev_;
93 const std::vector<int> z_next_;
94
95 const int zsum_;
96 static constexpr int zsrc_ = 2;
97 static constexpr int zdst_ = 3;
98 static constexpr int zsum2_ = 5;
99 static constexpr int zbase_ = 4;
100
101 const Zmm bf16_emu_reserv_1_ = zmm28;
102 const Zmm bf16_emu_reserv_2_ = zmm29;
103 const Reg64 bf16_emu_scratch_ = rax;
104 const Zmm bf16_emu_reserv_3_ = zmm30;
105 const Zmm bf16_emu_reserv_4_ = zmm31;
106
107 const bool emulateBfloat_;
108 const int regs_used_per_block_;
109 const int reg_block_;
110 static constexpr int vlen_ = utils::one_of(d_type, bf16, f16) ? 32 : 64;
111 std::unique_ptr<bf16_emulation_t> bf16_emu_ = nullptr;
112};
113
114} // namespace lrn
115} // namespace x64
116} // namespace cpu
117} // namespace impl
118} // namespace dnnl
119
120#endif
121