1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_TRANSPOSE_UTILS_HPP
18#define CPU_X64_JIT_TRANSPOSE_UTILS_HPP
19
20#include "cpu/x64/cpu_barrier.hpp"
21#include "cpu/x64/jit_primitive_conf.hpp"
22
23namespace dnnl {
24namespace impl {
25namespace cpu {
26namespace x64 {
27
28struct jit_trans_src_t {
29 struct ctx_t {
30 const void *src;
31 const void *tr_src;
32 const void *src_prf;
33 const void *tr_src_prf;
34 int ch_work;
35 };
36
37 virtual void operator()(ctx_t *ctx) = 0;
38 virtual status_t create_kernel() = 0;
39
40 jit_trans_src_t(const jit_conv_conf_t *conf) : conf_(conf) {}
41 virtual ~jit_trans_src_t() {}
42
43 const jit_conv_conf_t *conf_;
44};
45
// Plain argument block: one size value followed by four read-only pointers.
// NOTE(review): member order and types are kept exactly as they were, since
// generated code may address the fields by offset — confirm before
// reordering. Judging by the names, src / tr_src are the input and
// transposed-output buffers and the *_prf members are their prefetch
// counterparts; verify against the consumer of this struct.
struct jit_src_transpose_s {
    size_t size;
    const void *src;
    const void *tr_src;
    const void *src_prf;
    const void *tr_src_prf;
};
53
54struct jit_trans_dst_t {
55 struct ctx_t {
56 const void *src;
57 const void *tr_src;
58 const void *src_prf;
59 const void *tr_src_prf;
60 int ch_work;
61 };
62
63 jit_trans_dst_t(const jit_conv_conf_t *conf) : conf_(conf) {}
64 virtual ~jit_trans_dst_t() {}
65
66 virtual void operator()(ctx_t *ctx) = 0;
67 virtual status_t create_kernel() = 0;
68 const jit_conv_conf_t *conf_;
69};
70
// Tuning parameters consumed by jit_transpose4x16_src (which stores a
// pointer to an instance of this struct).
// NOTE(review): the names suggest prefetch settings — "pf0"/"pf1" most
// likely denote two prefetch levels, the *_distance members how far ahead
// to prefetch, and the booleans whether the second level is enabled.
// Units and exact meaning are defined by the kernel generator; confirm
// there.
struct jit_transpose4x16_src_t {
    int src_pf0_distance;
    int tr_src_pf0_distance;
    bool src_pf1;
    bool tr_src_pf1;
};
77
78struct jit_transpose4x16_src : public jit_generator {
79 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_transpose4x16_src)
80
81 jit_transpose4x16_src(const jit_1x1_conv_conf_t *aparams,
82 jit_transpose4x16_src_t *tparams_)
83 : jit_generator(jit_name()), params(aparams), tparams(tparams_) {}
84
85 const jit_1x1_conv_conf_t *params;
86 const jit_transpose4x16_src_t *tparams;
87
88 static const int transpose_size = 4;
89
90private:
91 static const int typesize = sizeof(float);
92
93 int src_stride = 0, tr_src_stride = 0;
94
95 Xbyak::Reg64 imm_addr64 = rbx;
96
97 Xbyak::Opmask kF0 = k1;
98 Xbyak::Opmask kCC = k2;
99 Xbyak::Opmask k33 = k3;
100 Xbyak::Opmask kFFFF = k4;
101
102 Xbyak::Zmm vidx01 = zmm31;
103 Xbyak::Zmm vidx10 = zmm30;
104 Xbyak::Zmm vidx1 = zmm29;
105 Xbyak::Zmm vidxP = zmm28;
106
107 Xbyak::Reg64 reg_src = r8;
108 Xbyak::Reg64 reg_tr_src = r9;
109 Xbyak::Reg64 reg_src_prf = r10;
110 Xbyak::Reg64 reg_tr_src_prf = r11;
111 Xbyak::Reg64 reg_loop = r12;
112 Xbyak::Reg64 reg_tr_src_tmp = r13;
113 Xbyak::Reg32 regw_tmp = r14d;
114
115 void transpose_block(int ur, int nrows);
116 void transpose(int nrows);
117 void generate() override;
118};
119
120struct jit_diff_wei_trans_to_vnni_t : public jit_generator {
121 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_diff_wei_trans_to_vnni_t)
122
123 jit_diff_wei_trans_to_vnni_t(const data_type_t dt, const int &kd,
124 const int &kh, const int &kw, const int &ic_block,
125 const int &oc_block)
126 : jit_generator(jit_name())
127 , out_dt_(dt)
128 , kd_(kd)
129 , kh_(kh)
130 , kw_(kw)
131 , ic_block_(ic_block)
132 , oc_block_(oc_block) {}
133
134 ~jit_diff_wei_trans_to_vnni_t() {}
135
136 status_t create_kernel() override { return jit_generator::create_kernel(); }
137
138 const data_type_t out_dt_;
139 const int kd_, kh_, kw_;
140 const int ic_block_, oc_block_;
141
142private:
143 void generate() override;
144};
145
146jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf);
147jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf);
148
149} // namespace x64
150} // namespace cpu
151} // namespace impl
152} // namespace dnnl
153
154#endif
155