/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <algorithm>
#include <cassert>
#include <set>

#include "common/memory_desc_wrapper.hpp"
#include "cpu/x64/prelu/jit_prelu_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace prelu {
cpu_isa_t get_supported_isa() {
    if (mayiuse(avx512_core_fp16))
        return avx512_core_fp16;
    else if (mayiuse(avx512_core_bf16))
        return avx512_core_bf16;
    else if (mayiuse(avx512_core))
        return avx512_core;
    else if (mayiuse(avx2))
        return avx2;
    else if (mayiuse(avx))
        return avx;
    else if (mayiuse(sse41))
        return sse41;

    return isa_undef;
}

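// Vector register length in bytes for the given ISA (sse41 as the fallback).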
static int get_vlen(const cpu_isa_t &isa) noexcept {
    if (isa == avx512_core_fp16)
        return cpu_isa_traits<avx512_core_fp16>::vlen;
    else if (isa == avx512_core_bf16)
        return cpu_isa_traits<avx512_core_bf16>::vlen;
    else if (isa == avx512_core)
        return cpu_isa_traits<avx512_core>::vlen;
    else if (isa == avx2)
        return cpu_isa_traits<avx2>::vlen;
    else if (isa == avx)
        return cpu_isa_traits<avx>::vlen;
    return cpu_isa_traits<sse41>::vlen;
}

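// Number of vector registers available for the given ISA.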
int get_n_vregs(const cpu_isa_t &isa) noexcept {
    if (isa == avx512_core_fp16)
        return cpu_isa_traits<avx512_core_fp16>::n_vregs;
    else if (isa == avx512_core_bf16)
        return cpu_isa_traits<avx512_core_bf16>::n_vregs;
    else if (isa == avx512_core)
        return cpu_isa_traits<avx512_core>::n_vregs;
    else if (isa == avx2)
        return cpu_isa_traits<avx2>::n_vregs;
    else if (isa == avx)
        return cpu_isa_traits<avx>::n_vregs;
    return cpu_isa_traits<sse41>::n_vregs;
}

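// Returns true if any of the tensor data types is s8 or u8.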
bool is_s8u8(const std::set<data_type_t> &tensor_data_types) noexcept {
    return std::any_of(tensor_data_types.cbegin(), tensor_data_types.cend(),
            [](const data_type_t &dt) {
                return utils::one_of(dt, data_type::s8, data_type::u8);
            });
}

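// SIMD width expressed in f32 lanes. With plain AVX and s8/u8 data the width
// is limited to a 128-bit (Xmm) register; otherwise the full vector length of
// the detected ISA is used.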
int get_simd_w(const std::set<data_type_t> &tensor_data_types) noexcept {
    const auto &isa = prelu::get_supported_isa();

    return (isa == avx && is_s8u8(tensor_data_types))
            ? vreg_traits<Xbyak::Xmm>::vlen / sizeof(float)
            : prelu::get_vlen(isa) / sizeof(float);
}

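// Elementwise comparison of the first ndims entries of two dims arrays.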
static bool dims_equal(
        const dims_t &lhs_dims, const dims_t &rhs_dims, const dim_t ndims) {

    for (dim_t i = 0; i < ndims; ++i) {
        if (lhs_dims[i] != rhs_dims[i]) return false;
    }

    return true;
}

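// Full broadcast: both tensors have the same number of dimensions, the same
// shape and the same format kind; for blocking descriptors the strides and
// inner blocking must match as well.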
static bool is_full_bcast(
        const memory_desc_wrapper &lhs, const memory_desc_wrapper &rhs) {
    const auto lhs_ndims = lhs.ndims();
    const auto rhs_ndims = rhs.ndims();
    const dims_t &lhs_dims = lhs.dims();
    const dims_t &rhs_dims = rhs.dims();

    if (lhs_ndims == rhs_ndims && dims_equal(lhs_dims, rhs_dims, lhs_ndims)
            && lhs.format_kind() == rhs.format_kind()) {

        if (lhs.is_blocking_desc()) {
            const auto &lhs_bd = lhs.blocking_desc();
            const auto &rhs_bd = rhs.blocking_desc();
            const dims_t &lhs_strides = lhs_bd.strides;
            const dims_t &rhs_strides = rhs_bd.strides;
            const dims_t &lhs_inner_blks = lhs_bd.inner_blks;
            const dims_t &rhs_inner_blks = rhs_bd.inner_blks;
            const dims_t &lhs_inner_idxs = lhs_bd.inner_idxs;
            const dims_t &rhs_inner_idxs = rhs_bd.inner_idxs;

            return lhs_bd.inner_nblks == rhs_bd.inner_nblks
                    && dims_equal(lhs_strides, rhs_strides, lhs_ndims)
                    && dims_equal(lhs_inner_blks, rhs_inner_blks, lhs_ndims)
                    && dims_equal(lhs_inner_idxs, rhs_inner_idxs, lhs_ndims);
        }

        return true;
    }

    return false;
}

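// Per-output-channel broadcast: rhs has shape 1 x C x 1 x ... x 1 with C
// equal to the channel dimension of lhs.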
static bool is_per_oc_bcast(
        const memory_desc_wrapper &lhs, const memory_desc_wrapper &rhs) {

    const auto &rhs_dims = rhs.dims();
    const auto &lhs_dims = lhs.dims();
    const auto &rhs_ndims = rhs.ndims();

    bool bcast_per_oc_exists = rhs_dims[0] == 1 && rhs_dims[1] == lhs_dims[1];

    if (bcast_per_oc_exists) {
        for (int dim_id = 2; dim_id < rhs_ndims; ++dim_id) {
            bcast_per_oc_exists = bcast_per_oc_exists && rhs_dims[dim_id] == 1;
        }
    }

    return bcast_per_oc_exists;
}

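// Classifies how rhs (weights) broadcasts over lhs (source): full broadcast,
// per-channel over a blocked layout, per-channel with a channels-last plain
// layout, per-channel with a channels-first plain layout, or unsupported.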
bcast get_bcast_type(
        const memory_desc_wrapper &lhs, const memory_desc_wrapper &rhs) {

    if (is_full_bcast(lhs, rhs)) return bcast::full;
    const auto &lhs_ndims = lhs.ndims();
    const auto &rhs_ndims = rhs.ndims();

    if (lhs_ndims != rhs_ndims || lhs_ndims < 2) return bcast::unsupported;

    if (is_per_oc_bcast(lhs, rhs)) {
        const auto &strides = lhs.blocking_desc().strides;

        if (!lhs.is_plain())
            return bcast::per_oc_blocked;
        else if (strides[1] == 1)
            return bcast::per_oc_n_spatial_c;
        else if (strides[0] >= strides[1]
                && IMPLICATION(lhs_ndims >= 3, strides[1] >= strides[2]))
            return bcast::per_oc_n_c_spatial;
    }

    return bcast::unsupported;
}

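// Verifies that every tensor data type can be handled: bf16 requires
// avx512_core and f16 requires avx512_core_fp16.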
bool dt_supported(const std::set<data_type_t> &tensor_data_types) noexcept {

    const bool is_avx512_core = mayiuse(avx512_core);
    const bool is_avx512_core_fp16 = mayiuse(avx512_core_fp16);

    auto is_dt_ok = [&](data_type_t dt) {
        return utils::one_of(dt, data_type::bf16, data_type::f16,
                       data_type::f32, data_type::s32, data_type::u8,
                       data_type::s8)
                && IMPLICATION(dt == data_type::bf16, is_avx512_core)
                && IMPLICATION(dt == data_type::f16, is_avx512_core_fp16);
    };

    for (auto dt : tensor_data_types)
        if (!is_dt_ok(dt)) return false;

    return true;
}

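// Number of elements per single batch (dim 0) entry, optionally counting the
// padded area of the memory descriptor.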
size_t c_blk_nelems(const memory_desc_t *mem, bool padding) noexcept {
    const memory_desc_wrapper mem_d {mem};
    return mem_d.nelems(padding) / mem_d.dims()[0];
}

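// Number of padded channels: the difference between the padded and the real
// channel (dim 1) size.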
size_t get_block_tail_size(const memory_desc_t *mem) noexcept {
    const memory_desc_wrapper mem_d {mem};
    return mem_d.padded_dims()[1] - mem_d.dims()[1];
}

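// Emits code that zeroes block_tail_size elements of type dt, starting
// tail_size elements past reg_dst (plus an optional element offset taken from
// *reg_offset), using rep stosb. Clobbers eax, rdi and rcx.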
void apply_zero_padding(jit_generator *host, const size_t tail_size,
        const data_type_t dt, const size_t block_tail_size,
        const Xbyak::Reg64 &reg_dst, const Xbyak::Reg64 *reg_offset) noexcept {
    using namespace Xbyak;
    using namespace Xbyak::util;

    const Reg32 &reg_zero = eax;
    const Reg64 &reg_ptr = rdi;
    const Reg64 &reg_counter = rcx;
    const auto dt_size = types::data_type_size(dt);
    const auto off_start = tail_size * dt_size;
    const auto off_end = off_start + block_tail_size * dt_size;

    host->xor_(reg_zero, reg_zero);
    if (reg_offset == nullptr)
        host->lea(reg_ptr, ptr[reg_dst + off_start]);
    else
        host->lea(reg_ptr, ptr[reg_dst + (*reg_offset * dt_size) + off_start]);
    host->mov(reg_counter, off_end - off_start);
    host->rep();
    host->stosb();
}

} // namespace prelu
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl