1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_PRELU_JIT_PRELU_BACKWARD_KERNEL_HPP
18#define CPU_X64_PRELU_JIT_PRELU_BACKWARD_KERNEL_HPP
19
20#include <map>
21#include <utility>
22
23#include "cpu/cpu_prelu_pd.hpp"
24#include "cpu/x64/cpu_isa_traits.hpp"
25#include "cpu/x64/prelu/jit_prelu_base_kernel.hpp"
26#include "cpu/x64/utils/jit_io_helper.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31namespace x64 {
32
33class jit_prelu_backward_kernel_t : public jit_prelu_base_kernel_t {
34public:
35 static jit_prelu_backward_kernel_t *create(const cpu_prelu_bwd_pd_t *pd);
36
37 struct call_params_t {
38 const void *src = nullptr, *weights = nullptr, *dst_diff = nullptr;
39 void *src_diff = nullptr, *weights_diff = nullptr;
40 size_t compute_data_size = 0u;
41 };
42
43 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_prelu_backward_kernel_t)
44
45 void operator()(jit_prelu_backward_kernel_t::call_params_t *params) {
46 jit_generator::operator()(params);
47 }
48
49protected:
50 jit_prelu_backward_kernel_t(const cpu_prelu_bwd_pd_t *pd,
51 const cpu_isa_t &isa, const int vlen,
52 size_t number_vmm_single_compute);
53 Xbyak::Address data_ptr(int arg_num, size_t offt = 0);
54
55 const cpu_prelu_bwd_pd_t *pd_;
56 const Xbyak::Reg64 &reg_weights_ = r10;
57 const Xbyak::Reg64 &reg_weights_diff_ = r11;
58
59 const data_type_t src_dt_;
60 const data_type_t wei_dt_;
61 const data_type_t diff_src_dt_;
62 const data_type_t diff_dst_dt_;
63 const data_type_t diff_wei_dt_;
64 const size_t diff_src_block_tail_;
65 const size_t diff_wei_block_tail_;
66
67 const Xbyak::Reg64 &reg_src_ = r12;
68 const Xbyak::Reg64 &reg_src_diff_ = r13;
69 const Xbyak::Reg64 &reg_dst_diff_ = r14;
70
71private:
72 bool any_tensor_bf16() const override;
73 void load_kernel_call_params() override;
74};
75
76template <typename Vmm>
77class jit_uni_prelu_backward_kernel_t : public jit_prelu_backward_kernel_t {
78public:
79 jit_uni_prelu_backward_kernel_t(
80 const cpu_prelu_bwd_pd_t *pd, const cpu_isa_t &isa);
81 ~jit_uni_prelu_backward_kernel_t() override;
82
83private:
84 void prepare_kernel_const_vars() override;
85 void compute_dst(size_t unrolling_factor, bool tail) override;
86 const Xbyak::Operand &get_or_load_weights(
87 const Xbyak::Address &src_addr, const Vmm &dst_vmm, bool tail);
88 void accumulate_weights_diff(const Vmm &partial_sum_vmm, const Vmm &tmp_vmm,
89 const Xbyak::Address &dst_addr, bool tail);
90 void finalize() override;
91 std::map<data_type_t, io::io_saturation_conf_t>
92 create_saturation_vmm_map() const;
93
94 const bool saturation_needed_diff_src_;
95 const bool saturation_needed_diff_weights_;
96
97 const Vmm tail_vmm_mask_; // Keep it higher to preserve idx=0 tail register
98 const Vmm vmm_zeros_;
99 const Vmm saturation_ubound_diff_src_;
100 const Vmm saturation_ubound_diff_weights_;
101
102 const Vmm vmm_ones_;
103 const Vmm weights_const_vmm_;
104 const Vmm weights_diff_acc_vmm_;
105
106 const Xbyak::Opmask &tail_opmask_ = k1;
107 const Xbyak::Reg64 &reg_tmp_ = r15;
108
109 io::jit_io_multi_dt_helper_t<Vmm> io_;
110};
111
112} // namespace x64
113} // namespace cpu
114} // namespace impl
115} // namespace dnnl
116
117#endif
118