1/*******************************************************************************
2* Copyright 2017-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_SSE41_1X1_CONV_KERNEL_F32_HPP
18#define CPU_X64_JIT_SSE41_1X1_CONV_KERNEL_F32_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/memory.hpp"
22
23#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
24#include "cpu/x64/jit_generator.hpp"
25#include "cpu/x64/jit_primitive_conf.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace x64 {
31
// JIT code generator for a 1x1 convolution kernel targeting the SSE4.1 ISA.
// Vector width and the post-ops injector are both specialized for sse41, and
// the vector width is counted in float elements (see simd_w_).
//
// The public surface is the constructor, init_conf() and the members jcp /
// attr_; the private members describe the register allocation, stack layout
// and code-emission helpers used by generate() (defined out of line).
struct jit_sse41_1x1_conv_kernel_f32 : public jit_generator {
    // ajcp:   kernel configuration, copied into the `jcp` member.
    // attr:   primitive attributes; bound by reference to `attr_`, so the
    //         caller's object must outlive every use of attr_.
    // dst_md: destination memory descriptor. NOTE(review): presumably used to
    //         configure the post-ops injector; confirm in the .cpp.
    jit_sse41_1x1_conv_kernel_f32(const jit_1x1_conv_conf_t &ajcp,
            const primitive_attr_t &attr, const memory_desc_t &dst_md);

    // Fills `jcp` from the convolution descriptor, the src/weights/dst memory
    // descriptors, the attributes and the thread count. The returned status_t
    // reports whether this kernel can handle the given problem (details live
    // in the out-of-line definition).
    static status_t init_conf(jit_1x1_conv_conf_t &jcp,
            const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
            const memory_desc_wrapper &weights_d,
            const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
            int nthreads);

    // Project macro; presumably declares the jit_generator bookkeeping hooks
    // (e.g. kernel name) for this class. Verify in jit_generator.hpp.
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse41_1x1_conv_kernel_f32)

    // Kernel configuration, held by value.
    jit_1x1_conv_conf_t jcp;
    // Primitive attributes, held by reference (see constructor note).
    const primitive_attr_t &attr_;

private:
    // Number of float lanes in one sse41 vector register.
    static constexpr auto simd_w_ = cpu_isa_traits<sse41>::vlen / sizeof(float);
    using reg64_t = const Xbyak::Reg64;
    using xmm_t = const Xbyak::Xmm;

    // General-purpose register assignment for the generated code.
    // Several names below deliberately share one physical register; they must
    // never be live at the same time in the emitted code.
    reg64_t reg_bcast_data = rax;
    reg64_t reg_load_data = rsi;
    reg64_t reg_output_data = rbx;
    reg64_t aux_reg_bcast_data = rdx;
    reg64_t aux1_reg_bcast_data = abi_not_param1;
    // Reuses the first ABI argument register. NOTE(review): the
    // reg_abi_param1_backup stack slot below presumably preserves its
    // original value; confirm in the .cpp.
    reg64_t aux_reg_load_data = abi_param1;
    reg64_t aux_reg_output_data = rbp;
    reg64_t reg_load_loop_work = r9;
    reg64_t reg_bcast_loop_work = r10;
    reg64_t reg_reduce_loop_work = r11;
    reg64_t load_loop_iter = r13;
    // Aliases load_loop_iter (r13). Must stay declared after load_loop_iter:
    // members are initialized in declaration order.
    reg64_t imm_addr64 = load_loop_iter;
    reg64_t bcast_loop_iter = r14;
    reg64_t reduce_loop_iter = r15;
    reg64_t reg_reduce_pos_flag = r8;
    // reg_output_stride and reg_bias_data both map to r12.
    // NOTE(review): this looks safe only if the two are used on mutually
    // exclusive code paths (e.g. different propagation kinds); verify in the
    // .cpp before changing either assignment.
    reg64_t reg_output_stride = r12;
    reg64_t reg_bias_data = r12;
    // Aliases bcast_loop_iter (r14); declaration order matters as above.
    reg64_t reg_diff_bias_data = bcast_loop_iter;

    // Stack frame layout: three 8-byte slots, addressed by byte offset.
    constexpr static int reg64_size_ = sizeof(int64_t);
    constexpr static int reg_diff_bias_data_stack_offt = 0;
    constexpr static int reg_binary_post_op_acc_off = 1 * reg64_size_;
    constexpr static int reg_abi_param1_backup = 2 * reg64_size_;
    // Total bytes covering the three slots above.
    constexpr static int stack_space_needed = 3 * reg64_size_;

    // xmm15 is named reg_bcast. NOTE(review): presumably holds the broadcast
    // operand inside the reduce loop; confirm in the .cpp.
    xmm_t reg_bcast = xmm_t(15);

    // Owned post-ops injector specialized for sse41. NOTE(review): may be
    // left null when no post-ops are configured; check before use.
    std::unique_ptr<injector::jit_uni_postops_injector_t<sse41>>
            postops_injector_;

    // Code-emission helpers called while generating the kernel body.
    void generate_bcast_loop(int load_loop_blk);
    void generate_reduce_loop(int load_loop_blk, int ur);
    void generate_diff_bias_loop(int load_loop_blk);

    // Emits the whole kernel (jit_generator entry point).
    void generate() override;

    // Emits post-op code for a load_loop_blk x ur block of results.
    void apply_postops(const int load_loop_blk, const int ur);
    // Returns an output-pointer offset for indices (i, j, n) in the forward
    // pass; the indexing scheme is defined out of line.
    size_t get_fwd_output_ptr_l_off(int i, int j, int n) const;
};
91
92} // namespace x64
93} // namespace cpu
94} // namespace impl
95} // namespace dnnl
96
97#endif
98