1 | /******************************************************************************* |
2 | * Copyright 2020-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_PRELU_JIT_PRELU_BASE_KERNEL_HPP_ |
18 | #define CPU_X64_PRELU_JIT_PRELU_BASE_KERNEL_HPP_ |
19 | |
20 | #include "cpu/x64/jit_generator.hpp" |
21 | #include "cpu/x64/prelu/jit_prelu_utils.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace cpu { |
26 | namespace x64 { |
27 | |
28 | class jit_prelu_base_kernel_t : public jit_generator { |
29 | public: |
30 | jit_prelu_base_kernel_t(const cpu_isa_t &isa, const int vlen, |
31 | const prelu::bcast &bcast, const memory_desc_wrapper &tensor_md, |
32 | const size_t number_vmm_single_compute, const char *name); |
33 | |
34 | size_t simd_w() const noexcept; |
35 | prelu::bcast get_bcast() const noexcept; |
36 | |
37 | protected: |
38 | int reserve_vmm(); |
39 | int get_compute_vmm(size_t base_idx, size_t unroll_group) const; |
40 | |
41 | size_t get_number_reserved_vmms() const noexcept; |
42 | |
43 | const cpu_isa_t isa_; |
44 | const size_t simd_w_ = 0; |
45 | const prelu::bcast bcast_ = prelu::bcast::unsupported; |
46 | const size_t tail_size_ = 0u; |
47 | const Xbyak::Reg64 ®_data_size_ = r8; |
48 | const Xbyak::Reg64 ®_offset_ = r9; |
49 | |
50 | private: |
51 | void generate() override; |
52 | virtual bool any_tensor_bf16() const = 0; |
53 | virtual void load_kernel_call_params() = 0; |
54 | virtual void prepare_kernel_const_vars() = 0; |
55 | virtual void compute_dst(size_t unrolling_factor, bool tail) = 0; |
56 | virtual void finalize() = 0; |
57 | size_t calc_unrolling_factor() const noexcept; |
58 | size_t calc_tail_size(const memory_desc_wrapper &tensor_md) const noexcept; |
59 | const memory_desc_wrapper tensor_md_; |
60 | const size_t number_vmm_single_compute_ = 0; |
61 | size_t number_reserved_vmms_ = 0; |
62 | }; |
63 | |
64 | } // namespace x64 |
65 | } // namespace cpu |
66 | } // namespace impl |
67 | } // namespace dnnl |
68 | |
69 | #endif |
70 | |