/*******************************************************************************
* Copyright 2020-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

17#ifndef CPU_X64_PRELU_JIT_PRELU_BASE_KERNEL_HPP_
18#define CPU_X64_PRELU_JIT_PRELU_BASE_KERNEL_HPP_
19
20#include "cpu/x64/jit_generator.hpp"
21#include "cpu/x64/prelu/jit_prelu_utils.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
26namespace x64 {
27
28class jit_prelu_base_kernel_t : public jit_generator {
29public:
30 jit_prelu_base_kernel_t(const cpu_isa_t &isa, const int vlen,
31 const prelu::bcast &bcast, const memory_desc_wrapper &tensor_md,
32 const size_t number_vmm_single_compute, const char *name);
33
34 size_t simd_w() const noexcept;
35 prelu::bcast get_bcast() const noexcept;
36
37protected:
38 int reserve_vmm();
39 int get_compute_vmm(size_t base_idx, size_t unroll_group) const;
40
41 size_t get_number_reserved_vmms() const noexcept;
42
43 const cpu_isa_t isa_;
44 const size_t simd_w_ = 0;
45 const prelu::bcast bcast_ = prelu::bcast::unsupported;
46 const size_t tail_size_ = 0u;
47 const Xbyak::Reg64 &reg_data_size_ = r8;
48 const Xbyak::Reg64 &reg_offset_ = r9;
49
50private:
51 void generate() override;
52 virtual bool any_tensor_bf16() const = 0;
53 virtual void load_kernel_call_params() = 0;
54 virtual void prepare_kernel_const_vars() = 0;
55 virtual void compute_dst(size_t unrolling_factor, bool tail) = 0;
56 virtual void finalize() = 0;
57 size_t calc_unrolling_factor() const noexcept;
58 size_t calc_tail_size(const memory_desc_wrapper &tensor_md) const noexcept;
59 const memory_desc_wrapper tensor_md_;
60 const size_t number_vmm_single_compute_ = 0;
61 size_t number_reserved_vmms_ = 0;
62};
63
64} // namespace x64
65} // namespace cpu
66} // namespace impl
67} // namespace dnnl
68
69#endif
70