1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cassert> |
18 | |
19 | #include "common/bfloat16.hpp" |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | |
24 | #include "cpu/x64/jit_generator.hpp" |
25 | #include "cpu/x64/shuffle/jit_uni_shuffle.hpp" |
26 | |
27 | namespace dnnl { |
28 | namespace impl { |
29 | namespace cpu { |
30 | namespace x64 { |
31 | |
32 | static bool impl_supports_datatype(cpu_isa_t isa, data_type_t data_type) { |
33 | switch (data_type) { |
34 | case data_type::bf16: return is_superset(isa, avx512_core); |
35 | case data_type::f16: return is_superset(isa, avx512_core_fp16); |
36 | case data_type::f32: |
37 | case data_type::s32: |
38 | case data_type::s8: |
39 | case data_type::u8: return true; |
40 | default: return false; |
41 | } |
42 | } |
43 | |
template <cpu_isa_t isa>
status_t jit_uni_shuffle_t<isa>::pd_t::init(engine_t *engine) {
    using namespace format_tag;
    using namespace data_type;

    // Forward reads src/dst; backward operates on the diff tensors, but the
    // shuffle math is the same, so one configuration serves both directions.
    const memory_desc_wrapper src_d(is_fwd() ? src_md() : diff_src_md());
    const memory_desc_wrapper dst_d(is_fwd() ? dst_md() : diff_dst_md());

    conf_.data_type = src_d.data_type();

    // Applicability checks: supported data type, matching src/dst types,
    // default attributes, shuffle along the channel axis (axis == 1) only,
    // and identical src/dst memory layouts.
    // NOTE(review): impl_supports_datatype() also covers f16/s8/u8, but the
    // one_of filter below restricts this impl to f32/s32/bf16 — confirm
    // whether the narrower list is intentional.
    const bool ok = mayiuse(isa)
            && utils::one_of(conf_.data_type, f32, s32, bf16)
            && src_d.data_type() == dst_d.data_type()
            && impl_supports_datatype(isa, conf_.data_type)
            && attr()->has_default_values() && axis() == 1
            && set_default_formats_common() && src_d == dst_d;

    if (!ok) return status::unimplemented;

    // Refine the dispatched ISA: prefer avx2 over plain avx when available,
    // and pick the bf16-capable ISA variant for bf16 data.
    conf_.isa = isa;
    if (isa == avx) conf_.isa = mayiuse(avx2) ? avx2 : avx;
    if (conf_.data_type == bf16)
        conf_.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16 : avx512_core;

    // Only blocked channel layouts (4/8/16-channel blocks) are supported.
    const format_tag_t blocked_format
            = memory_desc_matches_one_of_tag(*src_d.md_, nCw16c, nChw16c,
                    nCdhw16c, nCw8c, nChw8c, nCdhw8c, nCw4c, nChw4c, nCdhw4c);

    if (blocked_format == format_tag::undef) return status::unimplemented;

    // For the matched nC..c tags the stride of the innermost spatial
    // dimension equals the channel block size.
    conf_.blk_size = src_d.blocking_desc().strides[ndims() - 1];
    conf_.simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);

    const bool has_spatial = utils::one_of(ndims(), 3, 4, 5);
    const dim_t HW = H() * W();
    conf_.sp = has_spatial ? D() * HW : HW;

    // Require the channel block to hold at least one full SIMD vector.
    if (conf_.simd_w <= conf_.blk_size) {
        conf_.tag_kind = jit_memory_tag_kind_t::blocked;
        // Leftover channels that do not fill a whole SIMD register.
        conf_.simd_tail = C() % conf_.simd_w;
        conf_.c_split_size = conf_.blk_size;
        // When the channel count is small relative to the spatial size,
        // split the spatial dimension into gcd(sp, nthreads) equal chunks
        // to expose more parallelism; otherwise process it in one piece.
        if (C() < std::sqrt(conf_.sp))
            conf_.sp_split_size
                    = conf_.sp / math::gcd(conf_.sp, dnnl_get_max_threads());
        else
            conf_.sp_split_size = conf_.sp;
    } else
        return status::unimplemented;

    // Cache problem geometry for the kernel and for execute().
    conf_.ndims = ndims();
    conf_.mb = MB();
    conf_.c = C();
    conf_.d = D();
    conf_.h = H();
    conf_.w = W();

    conf_.dt_size = types::data_type_size(conf_.data_type);
    conf_.stride_mb = src_d.blocking_desc().strides[0];
    conf_.group_size = group_size();
    conf_.axis = axis();
    conf_.axis_size = axis_size();
    // The precomputed offset table (precompute_offsets) stores entries of
    // this size; the kernel indexes it with this element width.
    conf_.el_size_of_indices = sizeof(unsigned);

    return status::success;
}
109 | |
110 | template <cpu_isa_t isa> |
111 | status_t jit_uni_shuffle_t<isa>::precompute_offsets() { |
112 | const auto conf = pd()->get_conf(); |
113 | const int axis_size = conf.axis_size; |
114 | const int group_size = conf.group_size; |
115 | const int transpose_row |
116 | = pd()->is_fwd() ? group_size : axis_size / group_size; |
117 | const int transpose_col |
118 | = pd()->is_fwd() ? axis_size / group_size : group_size; |
119 | std::vector<int> rev_transposed_(axis_size); |
120 | |
121 | // Precompute transposed axis helper array |
122 | parallel_nd(transpose_col, transpose_row, [&](dim_t i, dim_t j) { |
123 | rev_transposed_[j * transpose_col + i] = i * transpose_row + j; |
124 | }); |
125 | |
126 | const dim_t C = conf.c; |
127 | input_off_ = (unsigned *)malloc( |
128 | C * sizeof(unsigned), platform::get_cache_line_size()); |
129 | if (input_off_ == nullptr) return dnnl_out_of_memory; |
130 | |
131 | if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::blocked) { |
132 | const dim_t blk_size = conf.blk_size; |
133 | const dim_t CB = utils::div_up(C, blk_size); |
134 | const dim_t SP = conf.sp; |
135 | |
136 | // Precompute input offsets using transposed axis |
137 | parallel_nd(CB, [&](dim_t cb) { |
138 | const int blk_end = nstl::min(blk_size, C - cb * blk_size); |
139 | PRAGMA_OMP_SIMD() |
140 | for (int cc = 0; cc < blk_end; ++cc) { |
141 | const int off = cb * blk_size + cc; |
142 | const int &input_c = rev_transposed_[off]; |
143 | input_off_[off] = (input_c / blk_size * SP * blk_size |
144 | + input_c % blk_size) |
145 | * conf.dt_size; |
146 | } |
147 | }); |
148 | } else { |
149 | assert(!"Invalid memory format kind." ); |
150 | return status::invalid_arguments; |
151 | } |
152 | |
153 | return status::success; |
154 | } |
155 | |
156 | template <cpu_isa_t isa> |
157 | status_t jit_uni_shuffle_t<isa>::init(engine_t *engine) { |
158 | CHECK(precompute_offsets()); |
159 | CHECK(safe_ptr_assign( |
160 | kernel_, new jit_uni_shuffle_kernel_t<isa>(pd()->get_conf()))); |
161 | CHECK(kernel_->create_kernel()); |
162 | return status::success; |
163 | } |
164 | |
// Trivial constructor — all heavy lifting (offset table, kernel generation)
// happens later in init().
template <cpu_isa_t isa>
inline jit_uni_shuffle_t<isa>::jit_uni_shuffle_t(const pd_t *apd)
    : primitive_t(apd) {}
168 | |
template <cpu_isa_t isa>
jit_uni_shuffle_t<isa>::~jit_uni_shuffle_t() {
    // Releases the offset table allocated in precompute_offsets().
    free(this->input_off_);
}
173 | |
template <cpu_isa_t isa>
status_t jit_uni_shuffle_t<isa>::execute(const exec_ctx_t &ctx) const {
    using namespace prop_kind;
    using namespace utils;

    // Forward reads SRC and writes DST; backward reads DIFF_DST and writes
    // DIFF_SRC. The same kernel serves both — precompute_offsets() already
    // baked the (inverse) permutation into input_off_.
    const auto i_arg = pd()->is_fwd() ? DNNL_ARG_SRC : DNNL_ARG_DIFF_DST;
    const auto o_arg = pd()->is_fwd() ? DNNL_ARG_DST : DNNL_ARG_DIFF_SRC;
    auto input = CTX_IN_MEM(const uint8_t *, i_arg);
    auto output = CTX_OUT_MEM(uint8_t *, o_arg);

    const auto conf = pd()->get_conf();

    const dim_t MB = conf.mb;
    const dim_t SP = conf.sp;
    const dim_t C = conf.c;
    const dim_t stride_mb = conf.stride_mb;
    const int data_type_size = conf.dt_size;

    if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::blocked) {
        // Parallelize over (mini-batch, spatial chunk, channel block).
        const dim_t CB = utils::div_up(C, conf.c_split_size);
        const dim_t SPB = SP / conf.sp_split_size;
        parallel_nd(MB, SPB, CB, [&](dim_t mb, dim_t spb, dim_t cb) {
            // Channels handled by this block (last block may be partial).
            const dim_t c_work
                    = nstl::min(conf.c_split_size, C - cb * conf.c_split_size);
            const dim_t c_curr = cb * conf.c_split_size;
            const dim_t sp_work = conf.sp_split_size;
            const dim_t sp_curr = spb * sp_work;
            // Element offset of this spatial chunk inside the mini-batch;
            // a spatial step advances blk_size elements in blocked layout.
            const dim_t off = mb * stride_mb + sp_curr * conf.blk_size;

            jit_shuffle_call_s args;
            // src points at the chunk base; the kernel adds the per-channel
            // offsets from input_off_. dst additionally skips c_curr
            // channels (SP * c_curr elements in blocked layout).
            args.src = input + off * data_type_size;
            args.dst = output + (off + SP * c_curr) * data_type_size;

            args.cb_loop_size = c_work;
            // The last channel block may contain layout padding.
            args.is_padded_block = cb + 1 == CB;

            args.input_off_ptr = this->input_off_ + c_curr;
            (*kernel_)(&args);
        });
    } else {
        assert(!"Invalid memory format kind." );
        return status::invalid_arguments;
    }

    return status::success;
}
220 | |
// Explicit instantiations for the dispatched ISA variants. avx2 and the
// bf16-capable ISAs execute through the avx / avx512_core instantiations
// (conf_.isa is refined inside pd_t::init).
template struct jit_uni_shuffle_t<sse41>;
template struct jit_uni_shuffle_t<avx>;
template struct jit_uni_shuffle_t<avx512_core>;
224 | |
225 | } // namespace x64 |
226 | } // namespace cpu |
227 | } // namespace impl |
228 | } // namespace dnnl |
229 | |
230 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
231 | |