1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef CPU_X64_JIT_TRANSPOSE_UTILS_HPP |
18 | #define CPU_X64_JIT_TRANSPOSE_UTILS_HPP |
19 | |
20 | #include "cpu/x64/cpu_barrier.hpp" |
21 | #include "cpu/x64/jit_primitive_conf.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace cpu { |
26 | namespace x64 { |
27 | |
28 | struct jit_trans_src_t { |
29 | struct ctx_t { |
30 | const void *src; |
31 | const void *tr_src; |
32 | const void *src_prf; |
33 | const void *tr_src_prf; |
34 | int ch_work; |
35 | }; |
36 | |
37 | virtual void operator()(ctx_t *ctx) = 0; |
38 | virtual status_t create_kernel() = 0; |
39 | |
40 | jit_trans_src_t(const jit_conv_conf_t *conf) : conf_(conf) {} |
41 | virtual ~jit_trans_src_t() {} |
42 | |
43 | const jit_conv_conf_t *conf_; |
44 | }; |
45 | |
// Argument record holding a size plus source / transposed-source pointers
// and their *_prf counterparts (presumably prefetch addresses).
// NOTE(review): if generated code reads this record by field offset, the
// member order and types below must not change.
struct jit_src_transpose_s {
    size_t size; // NOTE(review): units (bytes vs. elements) not visible here
    const void *src, *tr_src;
    const void *src_prf, *tr_src_prf;
};
53 | |
54 | struct jit_trans_dst_t { |
55 | struct ctx_t { |
56 | const void *src; |
57 | const void *tr_src; |
58 | const void *src_prf; |
59 | const void *tr_src_prf; |
60 | int ch_work; |
61 | }; |
62 | |
63 | jit_trans_dst_t(const jit_conv_conf_t *conf) : conf_(conf) {} |
64 | virtual ~jit_trans_dst_t() {} |
65 | |
66 | virtual void operator()(ctx_t *ctx) = 0; |
67 | virtual status_t create_kernel() = 0; |
68 | const jit_conv_conf_t *conf_; |
69 | }; |
70 | |
// Tuning parameters read by jit_transpose4x16_src (which keeps a pointer to
// one of these as `tparams`).
// NOTE(review): the names suggest pf0 = prefetch distance and pf1 = whether a
// second prefetch is issued; confirm against the kernel's generate().
struct jit_transpose4x16_src_t {
    int src_pf0_distance, tr_src_pf0_distance;
    bool src_pf1, tr_src_pf1;
};
77 | |
78 | struct jit_transpose4x16_src : public jit_generator { |
79 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_transpose4x16_src) |
80 | |
81 | jit_transpose4x16_src(const jit_1x1_conv_conf_t *aparams, |
82 | jit_transpose4x16_src_t *tparams_) |
83 | : jit_generator(jit_name()), params(aparams), tparams(tparams_) {} |
84 | |
85 | const jit_1x1_conv_conf_t *params; |
86 | const jit_transpose4x16_src_t *tparams; |
87 | |
88 | static const int transpose_size = 4; |
89 | |
90 | private: |
91 | static const int typesize = sizeof(float); |
92 | |
93 | int src_stride = 0, tr_src_stride = 0; |
94 | |
95 | Xbyak::Reg64 imm_addr64 = rbx; |
96 | |
97 | Xbyak::Opmask kF0 = k1; |
98 | Xbyak::Opmask kCC = k2; |
99 | Xbyak::Opmask k33 = k3; |
100 | Xbyak::Opmask kFFFF = k4; |
101 | |
102 | Xbyak::Zmm vidx01 = zmm31; |
103 | Xbyak::Zmm vidx10 = zmm30; |
104 | Xbyak::Zmm vidx1 = zmm29; |
105 | Xbyak::Zmm vidxP = zmm28; |
106 | |
107 | Xbyak::Reg64 reg_src = r8; |
108 | Xbyak::Reg64 reg_tr_src = r9; |
109 | Xbyak::Reg64 reg_src_prf = r10; |
110 | Xbyak::Reg64 reg_tr_src_prf = r11; |
111 | Xbyak::Reg64 reg_loop = r12; |
112 | Xbyak::Reg64 reg_tr_src_tmp = r13; |
113 | Xbyak::Reg32 regw_tmp = r14d; |
114 | |
115 | void transpose_block(int ur, int nrows); |
116 | void transpose(int nrows); |
117 | void generate() override; |
118 | }; |
119 | |
120 | struct jit_diff_wei_trans_to_vnni_t : public jit_generator { |
121 | DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_diff_wei_trans_to_vnni_t) |
122 | |
123 | jit_diff_wei_trans_to_vnni_t(const data_type_t dt, const int &kd, |
124 | const int &kh, const int &kw, const int &ic_block, |
125 | const int &oc_block) |
126 | : jit_generator(jit_name()) |
127 | , out_dt_(dt) |
128 | , kd_(kd) |
129 | , kh_(kh) |
130 | , kw_(kw) |
131 | , ic_block_(ic_block) |
132 | , oc_block_(oc_block) {} |
133 | |
134 | ~jit_diff_wei_trans_to_vnni_t() {} |
135 | |
136 | status_t create_kernel() override { return jit_generator::create_kernel(); } |
137 | |
138 | const data_type_t out_dt_; |
139 | const int kd_, kh_, kw_; |
140 | const int ic_block_, oc_block_; |
141 | |
142 | private: |
143 | void generate() override; |
144 | }; |
145 | |
146 | jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf); |
147 | jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf); |
148 | |
149 | } // namespace x64 |
150 | } // namespace cpu |
151 | } // namespace impl |
152 | } // namespace dnnl |
153 | |
154 | #endif |
155 | |