1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_PRELU_JIT_PRELU_FORWARD_KERNEL_HPP
18#define CPU_X64_PRELU_JIT_PRELU_FORWARD_KERNEL_HPP
19
20#include <map>
21
22#include "cpu/cpu_prelu_pd.hpp"
23#include "cpu/x64/cpu_isa_traits.hpp"
24#include "cpu/x64/prelu/jit_prelu_base_kernel.hpp"
25#include "cpu/x64/utils/jit_io_helper.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace x64 {
31
32class jit_prelu_forward_kernel_t : public jit_prelu_base_kernel_t {
33public:
34 static jit_prelu_forward_kernel_t *create(const cpu_prelu_fwd_pd_t *pd);
35
36 struct call_params_t {
37 const void *src = nullptr, *weights = nullptr, *dst = nullptr;
38 size_t compute_data_size = 0u;
39 };
40
41 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_prelu_forward_kernel_t)
42
43 void operator()(jit_prelu_forward_kernel_t::call_params_t *params) {
44 jit_generator::operator()(params);
45 }
46
47protected:
48 const data_type_t src_dt_;
49 const data_type_t wei_dt_;
50 const data_type_t dst_dt_;
51 const size_t dst_tail_block_;
52
53 jit_prelu_forward_kernel_t(const cpu_prelu_fwd_pd_t *pd,
54 const cpu_isa_t &isa, const int vlen,
55 const size_t number_vmm_single_compute);
56 Xbyak::Address data_ptr(int arg_num, size_t offt = 0);
57
58private:
59 bool any_tensor_bf16() const override;
60 void load_kernel_call_params() override;
61 void finalize() override {}
62
63protected:
64 const Xbyak::Reg64 &reg_src_ = r10;
65 const Xbyak::Reg64 &reg_dst_ = r11;
66 const Xbyak::Reg64 &reg_weights_ = r12;
67 const cpu_prelu_fwd_pd_t *pd_;
68};
69
70template <typename Vmm>
71class jit_uni_prelu_forward_kernel_t : public jit_prelu_forward_kernel_t {
72public:
73 jit_uni_prelu_forward_kernel_t(
74 const cpu_prelu_fwd_pd_t *pd, const cpu_isa_t &isa);
75 ~jit_uni_prelu_forward_kernel_t() override;
76
77private:
78 using jit_generator::uni_vfmadd132ps;
79
80 void prepare_kernel_const_vars() override;
81 void compute_dst(size_t unrolling_factor, bool tail) override;
82 bool can_load_wei_from_addr_directly(bool tail) const noexcept;
83
84 Vmm get_or_load_weights(
85 const Xbyak::Address &src_addr, const Vmm &dst_vmm, bool tail);
86 void uni_vfmadd132ps(
87 const Vmm &x1, const Vmm &x2, const Xbyak::Operand &op, bool tail);
88 std::map<data_type_t, io::io_saturation_conf_t>
89 create_saturation_vmm_map() const;
90
91 const bool saturation_needed_ = false;
92 const Vmm tail_vmm_mask_; // Keep it higher to preserve idx=0 tail register
93 const Vmm vmm_zeros_;
94 const Vmm dst_saturate_ubound_;
95 const Vmm weights_const_vmm_;
96 const size_t number_vmm_single_compute_ = 0;
97 const Xbyak::Opmask &tail_opmask_ = k1;
98 const Xbyak::Reg64 &reg_tmp_ = r15;
99
100 io::jit_io_multi_dt_helper_t<Vmm> io_;
101};
102
103} // namespace x64
104} // namespace cpu
105} // namespace impl
106} // namespace dnnl
107
108#endif
109