/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17#ifndef CPU_X64_PRELU_JIT_PRELU_REDUCTION_HPP
18#define CPU_X64_PRELU_JIT_PRELU_REDUCTION_HPP
19
20#include <memory>
21
22#include "cpu/cpu_prelu_pd.hpp"
23#include "cpu/x64/cpu_isa_traits.hpp"
24#include "cpu/x64/jit_generator.hpp"
25#include "cpu/x64/utils/jit_io_helper.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace x64 {
31
32class jit_prelu_reduction_kernel_t : public jit_generator {
33public:
34 static jit_prelu_reduction_kernel_t *create(const cpu_prelu_bwd_pd_t *pd);
35
36 struct call_params_t {
37 size_t reduction_blocks = 0;
38 const void *weights_diff_scratch = nullptr;
39 void *weights_diff = nullptr;
40 bool tail = false;
41 bool is_last_c_blk = false;
42 };
43
44 void generate() override;
45 size_t simd_w() const;
46 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_prelu_reduction_kernel_t)
47
48 void operator()(jit_prelu_reduction_kernel_t::call_params_t *params) {
49 jit_generator::operator()(params);
50 }
51
52private:
53 void load_kernel_call_params();
54 virtual size_t get_unrolling_factor(bool tail) const = 0;
55 virtual void compute_dst(int unrolling_factor, bool tail) = 0;
56 virtual void prepare_kernel_const_vars(bool tail) = 0;
57 virtual void finalize(bool tail) = 0;
58 void generate(bool tail);
59
60 const Xbyak::Reg64 &reg_reduction_blocks_ = r8;
61 const Xbyak::Reg64 &reg_weights_diff_scratch_ = r10;
62 const Xbyak::Reg8 &reg_tail_ = r12b;
63
64 const size_t scratchpad_c_block_offset_ = 0;
65
66protected:
67 jit_prelu_reduction_kernel_t(const cpu_prelu_bwd_pd_t *pd, int simd_w);
68 Xbyak::Address diff_scratch_ptr(int unrolling_group) const;
69 int reserve_vmm();
70
71 const size_t simd_w_ = 0;
72 const data_type_t data_type_;
73 const size_t tail_size_ = 0;
74 const Xbyak::Reg64 &reg_offset_ = r9;
75 const Xbyak::Reg64 &reg_weights_diff_ = r11;
76 const Xbyak::Reg8 &reg_last_c_blk_byte_ = r13b;
77 size_t number_reserved_vmms_ = 0;
78 size_t tail_block_size_ = 0;
79 size_t c_blk_nelems_ = 0;
80};
81
82template <typename Vmm>
83class jit_uni_prelu_reduction_kernel_t : public jit_prelu_reduction_kernel_t {
84public:
85 jit_uni_prelu_reduction_kernel_t(
86 const cpu_prelu_bwd_pd_t *pd, const cpu_isa_t &isa);
87
88private:
89 size_t get_unrolling_factor(bool tail) const override;
90 void prepare_kernel_const_vars(bool tail) override;
91 void finalize(bool tail) override;
92 void compute_dst(int unrolling_factor, bool tail) override;
93
94 const cpu_isa_t isa_;
95 const bool saturation_needed_;
96 const Vmm tail_vmm_mask_; // Keep it higher to preserve idx=0 tail register
97 const Vmm accumulator_;
98 const Vmm saturation_lower_bound_;
99 const Vmm saturation_upper_bound_;
100
101 const Xbyak::Opmask &tail_opmask_ = k1;
102 const Xbyak::Reg64 &reg_tmp_ = r15;
103
104 io::jit_io_helper_t<Vmm> io_;
105};
106
107} // namespace x64
108} // namespace cpu
109} // namespace impl
110} // namespace dnnl
111
112#endif
113