1 | /******************************************************************************* |
2 | * Copyright 2017-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_SSE41_CONV_KERNEL_F32_HPP |
18 | #define CPU_X64_JIT_SSE41_CONV_KERNEL_F32_HPP |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/memory.hpp" |
22 | |
23 | #include "cpu/x64/injectors/jit_uni_postops_injector.hpp" |
24 | #include "cpu/x64/jit_generator.hpp" |
25 | #include "cpu/x64/jit_primitive_conf.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | struct jit_sse41_conv_fwd_kernel_f32 : public jit_generator { |
33 | jit_sse41_conv_fwd_kernel_f32(const jit_conv_conf_t &ajcp, |
34 | const primitive_attr_t &attr, const memory_desc_t &dst_md); |
35 | |
36 | static status_t init_conf(jit_conv_conf_t &jcp, |
37 | const convolution_desc_t &cd, const memory_desc_wrapper &src_d, |
38 | const memory_desc_wrapper &weights_d, |
39 | const memory_desc_wrapper &dst_d, const primitive_attr_t &attr, |
40 | int nthreads); |
41 | |
42 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_conv_fwd_kernel_f32) |
43 | jit_conv_conf_t jcp; |
44 | const primitive_attr_t &attr_; |
45 | |
46 | private: |
47 | static constexpr auto simd_w_ = cpu_isa_traits<sse41>::vlen / sizeof(float); |
48 | using reg64_t = const Xbyak::Reg64; |
49 | reg64_t reg_input = rax; |
50 | reg64_t aux_reg_input = r8; |
51 | reg64_t reg_kernel = rdx; |
52 | reg64_t aux_reg_kernel = r9; |
53 | reg64_t reg_output = rsi; |
54 | reg64_t reg_bias = rbx; |
55 | |
56 | reg64_t kj = r10; |
57 | reg64_t oi_iter = r11; |
58 | reg64_t ki_iter = r12; |
59 | reg64_t reg_kh = abi_not_param1; |
60 | reg64_t simd_iter = r15; |
61 | reg64_t reg_oc_blocks = r14; |
62 | reg64_t imm_addr64 = reg_oc_blocks; |
63 | |
64 | Xbyak::Reg32 reg_ci_flag = r13d; |
65 | |
66 | std::unique_ptr<injector::jit_uni_postops_injector_t<sse41>> |
67 | postops_injector_; |
68 | |
69 | inline void oh_step_unroll_kw( |
70 | int ur_w, int pad_l, int pad_r, int oc_blocks); |
71 | inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks); |
72 | inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks); |
73 | inline void solve_common(int oc_blocks); |
74 | |
75 | inline dim_t filter_w_to_input(int ki, int oi = 0, int pad_l = 0) { |
76 | return ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l; |
77 | } |
78 | |
79 | inline dim_t filter_h_to_input(int ki) { |
80 | return ki * (jcp.dilate_h + 1) * jcp.iw; |
81 | } |
82 | |
83 | inline dim_t get_input_offset(int i_ic, int i_iw) { |
84 | dim_t offset; |
85 | if (utils::one_of(jcp.src_tag, format_tag::ncw, format_tag::nchw, |
86 | format_tag::ncdhw)) { |
87 | offset = i_ic * jcp.ih * jcp.iw + i_iw; |
88 | } else if (utils::one_of(jcp.src_tag, format_tag::nwc, format_tag::nhwc, |
89 | format_tag::ndhwc)) { |
90 | offset = i_iw * jcp.ic * jcp.ngroups + i_ic; |
91 | } else { |
92 | offset = i_iw * jcp.ic_block + i_ic; |
93 | } |
94 | return sizeof(float) * offset; |
95 | } |
96 | |
97 | inline dim_t get_output_offset(int i_oc_block, int i_ow) { |
98 | dim_t offset; |
99 | if (utils::one_of(jcp.dst_tag, format_tag::nwc, format_tag::nhwc, |
100 | format_tag::ndhwc)) { |
101 | offset = i_ow * jcp.oc * jcp.ngroups + i_oc_block * jcp.oc_block; |
102 | } else { |
103 | offset = (i_oc_block * jcp.oh * jcp.ow + i_ow) * jcp.oc_block; |
104 | } |
105 | return sizeof(float) * offset; |
106 | } |
107 | |
108 | inline dim_t get_kernel_offset(int i_oc_block, int ki, int i_ic) { |
109 | dim_t block_step_size = jcp.ic_block * jcp.oc_block; |
110 | dim_t ic_block_step_size = jcp.kh * jcp.kw * block_step_size; |
111 | dim_t oc_block_step_size = jcp.nb_ic * ic_block_step_size; |
112 | dim_t offset = i_oc_block * oc_block_step_size + ki * block_step_size |
113 | + i_ic * jcp.oc_block; |
114 | return sizeof(float) * offset; |
115 | } |
116 | |
117 | void apply_postops(const int oc_blocks, const int ur_w); |
118 | |
119 | void generate() override; |
120 | }; |
121 | |
122 | } // namespace x64 |
123 | } // namespace cpu |
124 | } // namespace impl |
125 | } // namespace dnnl |
126 | |
127 | #endif |
128 | |