1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <algorithm> |
18 | #include <cassert> |
19 | #include <set> |
20 | |
21 | #include "common/memory_desc_wrapper.hpp" |
22 | #include "cpu/x64/prelu/jit_prelu_utils.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace cpu { |
27 | namespace x64 { |
28 | namespace prelu { |
29 | |
30 | cpu_isa_t get_supported_isa() { |
31 | if (mayiuse(avx512_core_fp16)) |
32 | return avx512_core_fp16; |
33 | else if (mayiuse(avx512_core_bf16)) |
34 | return avx512_core_bf16; |
35 | else if (mayiuse(avx512_core)) |
36 | return avx512_core; |
37 | else if (mayiuse(avx2)) |
38 | return avx2; |
39 | else if (mayiuse(avx)) |
40 | return avx; |
41 | else if (mayiuse(sse41)) |
42 | return sse41; |
43 | |
44 | return isa_undef; |
45 | } |
46 | |
47 | static int get_vlen(const cpu_isa_t &isa) noexcept { |
48 | if (isa == avx512_core_fp16) |
49 | return cpu_isa_traits<avx512_core_fp16>::vlen; |
50 | else if (isa == avx512_core_bf16) |
51 | return cpu_isa_traits<avx512_core_bf16>::vlen; |
52 | else if (isa == avx512_core) |
53 | return cpu_isa_traits<avx512_core>::vlen; |
54 | else if (isa == avx2) |
55 | return cpu_isa_traits<avx2>::vlen; |
56 | else if (isa == avx) |
57 | return cpu_isa_traits<avx>::vlen; |
58 | return cpu_isa_traits<sse41>::vlen; |
59 | } |
60 | |
61 | int get_n_vregs(const cpu_isa_t &isa) noexcept { |
62 | if (isa == avx512_core_fp16) |
63 | return cpu_isa_traits<avx512_core_fp16>::n_vregs; |
64 | else if (isa == avx512_core_bf16) |
65 | return cpu_isa_traits<avx512_core_bf16>::n_vregs; |
66 | else if (isa == avx512_core) |
67 | return cpu_isa_traits<avx512_core>::n_vregs; |
68 | else if (isa == avx2) |
69 | return cpu_isa_traits<avx2>::n_vregs; |
70 | else if (isa == avx) |
71 | return cpu_isa_traits<avx>::n_vregs; |
72 | return cpu_isa_traits<sse41>::n_vregs; |
73 | } |
74 | |
75 | bool is_s8u8(const std::set<data_type_t> &tensor_data_types) noexcept { |
76 | return std::any_of(tensor_data_types.cbegin(), tensor_data_types.cend(), |
77 | [](const data_type_t &dt) { |
78 | return utils::one_of(dt, data_type::s8, data_type::u8); |
79 | }); |
80 | } |
81 | |
82 | int get_simd_w(const std::set<data_type_t> &tensor_data_types) noexcept { |
83 | const auto &isa = prelu::get_supported_isa(); |
84 | |
85 | return (isa == avx && is_s8u8(tensor_data_types)) |
86 | ? vreg_traits<Xbyak::Xmm>::vlen / sizeof(float) |
87 | : prelu::get_vlen(isa) / sizeof(float); |
88 | } |
89 | |
90 | static bool dims_equal( |
91 | const dims_t &lhs_dims, const dims_t &rhs_dims, const dim_t ndims) { |
92 | |
93 | for (dim_t i = 0; i < ndims; ++i) { |
94 | if (lhs_dims[i] != rhs_dims[i]) return false; |
95 | } |
96 | |
97 | return true; |
98 | } |
99 | |
100 | static bool is_full_bcast( |
101 | const memory_desc_wrapper &lhs, const memory_desc_wrapper &rhs) { |
102 | const auto lhs_ndims = lhs.ndims(); |
103 | const auto rhs_ndims = rhs.ndims(); |
104 | const dims_t &lhs_dims = lhs.dims(); |
105 | const dims_t &rhs_dims = rhs.dims(); |
106 | |
107 | if (lhs_ndims == rhs_ndims && dims_equal(lhs_dims, rhs_dims, lhs_ndims) |
108 | && lhs.format_kind() == rhs.format_kind()) { |
109 | |
110 | if (lhs.is_blocking_desc()) { |
111 | const auto &lhs_bd = lhs.blocking_desc(); |
112 | const auto &rhs_bd = rhs.blocking_desc(); |
113 | const dims_t &lhs_strides = lhs_bd.strides; |
114 | const dims_t &rhs_strides = rhs_bd.strides; |
115 | const dims_t &lhs_inner_blks = lhs_bd.inner_blks; |
116 | const dims_t &rhs_inner_blks = rhs_bd.inner_blks; |
117 | const dims_t &lhs_inner_idxs = lhs_bd.inner_idxs; |
118 | const dims_t &rhs_inner_idxs = rhs_bd.inner_idxs; |
119 | |
120 | return lhs_bd.inner_nblks == rhs_bd.inner_nblks |
121 | && dims_equal(lhs_strides, rhs_strides, lhs_ndims) |
122 | && dims_equal(lhs_inner_blks, rhs_inner_blks, lhs_ndims) |
123 | && dims_equal(lhs_inner_idxs, rhs_inner_idxs, lhs_ndims); |
124 | } |
125 | |
126 | return true; |
127 | } |
128 | |
129 | return false; |
130 | } |
131 | |
132 | static bool is_per_oc_bcast( |
133 | const memory_desc_wrapper &lhs, const memory_desc_wrapper &rhs) { |
134 | |
135 | const auto &rhs_dims = rhs.dims(); |
136 | const auto &lhs_dims = lhs.dims(); |
137 | const auto &rhs_ndims = rhs.ndims(); |
138 | |
139 | bool bcast_per_oc_exists = rhs_dims[0] == 1 && rhs_dims[1] == lhs_dims[1]; |
140 | |
141 | if (bcast_per_oc_exists) { |
142 | for (int dim_id = 2; dim_id < rhs_ndims; ++dim_id) { |
143 | bcast_per_oc_exists = bcast_per_oc_exists && rhs_dims[dim_id] == 1; |
144 | } |
145 | } |
146 | |
147 | return bcast_per_oc_exists; |
148 | } |
149 | |
150 | bcast get_bcast_type( |
151 | const memory_desc_wrapper &lhs, const memory_desc_wrapper &rhs) { |
152 | |
153 | if (is_full_bcast(lhs, rhs)) return bcast::full; |
154 | const auto &lhs_ndims = lhs.ndims(); |
155 | const auto &rhs_ndims = rhs.ndims(); |
156 | |
157 | if (lhs_ndims != rhs_ndims || lhs_ndims < 2) return bcast::unsupported; |
158 | |
159 | if (is_per_oc_bcast(lhs, rhs)) { |
160 | const auto &strides = lhs.blocking_desc().strides; |
161 | |
162 | if (!lhs.is_plain()) |
163 | return bcast::per_oc_blocked; |
164 | else if (strides[1] == 1) |
165 | return bcast::per_oc_n_spatial_c; |
166 | else if (strides[0] >= strides[1] |
167 | && IMPLICATION(lhs_ndims >= 3, strides[1] >= strides[2])) |
168 | return bcast::per_oc_n_c_spatial; |
169 | } |
170 | |
171 | return bcast::unsupported; |
172 | } |
173 | |
174 | bool dt_supported(const std::set<data_type_t> &tensor_data_types) noexcept { |
175 | |
176 | const bool is_avx512_core = mayiuse(avx512_core); |
177 | const bool is_avx512_core_fp16 = mayiuse(avx512_core_fp16); |
178 | |
179 | auto is_dt_ok = [&](data_type_t dt) { |
180 | return utils::one_of(dt, data_type::bf16, data_type::f16, |
181 | data_type::f32, data_type::s32, data_type::u8, |
182 | data_type::s8) |
183 | && IMPLICATION(dt == data_type::bf16, is_avx512_core) |
184 | && IMPLICATION(dt == data_type::f16, is_avx512_core_fp16); |
185 | }; |
186 | |
187 | for (auto dt : tensor_data_types) |
188 | if (!is_dt_ok(dt)) return false; |
189 | |
190 | return true; |
191 | } |
192 | |
193 | size_t c_blk_nelems(const memory_desc_t *mem, bool padding) noexcept { |
194 | const memory_desc_wrapper mem_d {mem}; |
195 | return mem_d.nelems(padding) / mem_d.dims()[0]; |
196 | } |
197 | |
198 | size_t get_block_tail_size(const memory_desc_t *mem) noexcept { |
199 | const memory_desc_wrapper mem_d {mem}; |
200 | return mem_d.padded_dims()[1] - mem_d.dims()[1]; |
201 | } |
202 | |
// Emits JIT code that zero-fills the block tail of the destination tensor:
// `block_tail_size` elements of type `dt`, starting `tail_size` elements
// past `reg_dst` (plus `*reg_offset` elements if provided).
//
// Implemented via `rep stosb`, which architecturally requires AL as the
// fill byte, RDI as the destination pointer and RCX as the byte count —
// hence the hard-coded register choices below. The emitted code clobbers
// EAX, RDI and RCX; NOTE(review): it also assumes the direction flag is
// clear (forward stores), the usual ABI state — confirm callers never set
// DF.
void apply_zero_padding(jit_generator *host, const size_t tail_size,
        const data_type_t dt, const size_t block_tail_size,
        const Xbyak::Reg64 &reg_dst, const Xbyak::Reg64 *reg_offset) noexcept {
    using namespace Xbyak;
    using namespace Xbyak::util;

    // Fixed registers mandated by the `rep stosb` instruction.
    const Reg32 &reg_zero = eax;
    const Reg64 &reg_ptr = rdi;
    const Reg64 &reg_counter = rcx;
    const auto dt_size = types::data_type_size(dt);
    // Byte range [off_start, off_end) to be zeroed, relative to reg_dst
    // (+ optional element offset).
    const auto off_start = tail_size * dt_size;
    const auto off_end = off_start + block_tail_size * dt_size;

    host->xor_(reg_zero, reg_zero); // fill value = 0 (AL used by stosb)
    if (reg_offset == nullptr)
        host->lea(reg_ptr, ptr[reg_dst + off_start]);
    else
        host->lea(reg_ptr, ptr[reg_dst + (*reg_offset * dt_size) + off_start]);
    host->mov(reg_counter, off_end - off_start); // byte count
    host->rep();
    host->stosb(); // store AL to [RDI], RCX times
}
225 | |
226 | } // namespace prelu |
227 | } // namespace x64 |
228 | } // namespace cpu |
229 | } // namespace impl |
230 | } // namespace dnnl |
231 | |