1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
#include <cassert>
#include <cmath>

#include "common/bfloat16.hpp"
#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"

#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/shuffle/jit_uni_shuffle.hpp"
26
27namespace dnnl {
28namespace impl {
29namespace cpu {
30namespace x64 {
31
32static bool impl_supports_datatype(cpu_isa_t isa, data_type_t data_type) {
33 switch (data_type) {
34 case data_type::bf16: return is_superset(isa, avx512_core);
35 case data_type::f16: return is_superset(isa, avx512_core_fp16);
36 case data_type::f32:
37 case data_type::s32:
38 case data_type::s8:
39 case data_type::u8: return true;
40 default: return false;
41 }
42}
43
// Validates the shuffle descriptor against this JIT implementation and
// fills conf_ with everything the kernel generator and execute() need.
// Returns status::unimplemented for any unsupported configuration.
template <cpu_isa_t isa>
status_t jit_uni_shuffle_t<isa>::pd_t::init(engine_t *engine) {
    using namespace format_tag;
    using namespace data_type;

    // For backward, shuffle of diff_dst into diff_src is the same kernel
    // with the inverse permutation, so src/dst roles are just swapped.
    const memory_desc_wrapper src_d(is_fwd() ? src_md() : diff_src_md());
    const memory_desc_wrapper dst_d(is_fwd() ? dst_md() : diff_dst_md());

    conf_.data_type = src_d.data_type();

    // Only f32/s32/bf16, matching src/dst types, default attributes,
    // shuffle over the channel axis, and identical src/dst descriptors
    // are supported.
    const bool ok = mayiuse(isa)
            && utils::one_of(conf_.data_type, f32, s32, bf16)
            && src_d.data_type() == dst_d.data_type()
            && impl_supports_datatype(isa, conf_.data_type)
            && attr()->has_default_values() && axis() == 1
            && set_default_formats_common() && src_d == dst_d;

    if (!ok) return status::unimplemented;

    // Upgrade the dispatched ISA to the best available one at runtime:
    // avx -> avx2 when present; bf16 always runs on an AVX-512 kernel.
    conf_.isa = isa;
    if (isa == avx) conf_.isa = mayiuse(avx2) ? avx2 : avx;
    if (conf_.data_type == bf16)
        conf_.isa = mayiuse(avx512_core_bf16) ? avx512_core_bf16 : avx512_core;

    // Only channel-blocked layouts (4c/8c/16c over 1D/2D/3D spatial) are
    // handled by this implementation.
    const format_tag_t blocked_format
            = memory_desc_matches_one_of_tag(*src_d.md_, nCw16c, nChw16c,
                    nCdhw16c, nCw8c, nChw8c, nCdhw8c, nCw4c, nChw4c, nCdhw4c);

    if (blocked_format == format_tag::undef) return status::unimplemented;

    // Innermost stride of a blocked layout is the channel block size.
    conf_.blk_size = src_d.blocking_desc().strides[ndims() - 1];
    // SIMD width counted in 32-bit lanes (offsets table holds 32-bit
    // entries regardless of the data type size).
    conf_.simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);

    const bool has_spatial = utils::one_of(ndims(), 3, 4, 5);
    const dim_t HW = H() * W();
    conf_.sp = has_spatial ? D() * HW : HW;

    // Require the vector width to fit inside one channel block; otherwise
    // bail out (e.g. 16-lane AVX-512 on a 8c layout).
    if (conf_.simd_w <= conf_.blk_size) {
        conf_.tag_kind = jit_memory_tag_kind_t::blocked;
        conf_.simd_tail = C() % conf_.simd_w;
        conf_.c_split_size = conf_.blk_size;
        // With few channels relative to the spatial size, split the
        // spatial dimension across threads to expose more parallelism;
        // gcd keeps the split an exact divisor of sp.
        if (C() < std::sqrt(conf_.sp))
            conf_.sp_split_size
                    = conf_.sp / math::gcd(conf_.sp, dnnl_get_max_threads());
        else
            conf_.sp_split_size = conf_.sp;
    } else
        return status::unimplemented;

    // Plain problem-shape bookkeeping consumed by the kernel and execute().
    conf_.ndims = ndims();
    conf_.mb = MB();
    conf_.c = C();
    conf_.d = D();
    conf_.h = H();
    conf_.w = W();

    conf_.dt_size = types::data_type_size(conf_.data_type);
    conf_.stride_mb = src_d.blocking_desc().strides[0];
    conf_.group_size = group_size();
    conf_.axis = axis();
    conf_.axis_size = axis_size();
    // Precomputed offsets (see precompute_offsets()) are stored as
    // 32-bit unsigned values.
    conf_.el_size_of_indices = sizeof(unsigned);

    return status::success;
}
109
110template <cpu_isa_t isa>
111status_t jit_uni_shuffle_t<isa>::precompute_offsets() {
112 const auto conf = pd()->get_conf();
113 const int axis_size = conf.axis_size;
114 const int group_size = conf.group_size;
115 const int transpose_row
116 = pd()->is_fwd() ? group_size : axis_size / group_size;
117 const int transpose_col
118 = pd()->is_fwd() ? axis_size / group_size : group_size;
119 std::vector<int> rev_transposed_(axis_size);
120
121 // Precompute transposed axis helper array
122 parallel_nd(transpose_col, transpose_row, [&](dim_t i, dim_t j) {
123 rev_transposed_[j * transpose_col + i] = i * transpose_row + j;
124 });
125
126 const dim_t C = conf.c;
127 input_off_ = (unsigned *)malloc(
128 C * sizeof(unsigned), platform::get_cache_line_size());
129 if (input_off_ == nullptr) return dnnl_out_of_memory;
130
131 if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::blocked) {
132 const dim_t blk_size = conf.blk_size;
133 const dim_t CB = utils::div_up(C, blk_size);
134 const dim_t SP = conf.sp;
135
136 // Precompute input offsets using transposed axis
137 parallel_nd(CB, [&](dim_t cb) {
138 const int blk_end = nstl::min(blk_size, C - cb * blk_size);
139 PRAGMA_OMP_SIMD()
140 for (int cc = 0; cc < blk_end; ++cc) {
141 const int off = cb * blk_size + cc;
142 const int &input_c = rev_transposed_[off];
143 input_off_[off] = (input_c / blk_size * SP * blk_size
144 + input_c % blk_size)
145 * conf.dt_size;
146 }
147 });
148 } else {
149 assert(!"Invalid memory format kind.");
150 return status::invalid_arguments;
151 }
152
153 return status::success;
154}
155
156template <cpu_isa_t isa>
157status_t jit_uni_shuffle_t<isa>::init(engine_t *engine) {
158 CHECK(precompute_offsets());
159 CHECK(safe_ptr_assign(
160 kernel_, new jit_uni_shuffle_kernel_t<isa>(pd()->get_conf())));
161 CHECK(kernel_->create_kernel());
162 return status::success;
163}
164
// Constructor only forwards the primitive descriptor; all real setup
// (offset table, kernel generation) happens in init().
template <cpu_isa_t isa>
inline jit_uni_shuffle_t<isa>::jit_uni_shuffle_t(const pd_t *apd)
    : primitive_t(apd) {}
168
169template <cpu_isa_t isa>
170jit_uni_shuffle_t<isa>::~jit_uni_shuffle_t() {
171 free(this->input_off_);
172}
173
// Runs the shuffle: partitions the work over (mb, spatial split, channel
// block) and invokes the JIT kernel on each tile. The kernel gathers
// input elements through the precomputed input_off_ table.
template <cpu_isa_t isa>
status_t jit_uni_shuffle_t<isa>::execute(const exec_ctx_t &ctx) const {
    using namespace prop_kind;
    using namespace utils;

    // Backward pass reuses the same kernel with swapped tensor roles
    // (the inverse permutation was baked into input_off_).
    const auto i_arg = pd()->is_fwd() ? DNNL_ARG_SRC : DNNL_ARG_DIFF_DST;
    const auto o_arg = pd()->is_fwd() ? DNNL_ARG_DST : DNNL_ARG_DIFF_SRC;
    // Byte pointers: all offset arithmetic below is scaled by dt_size.
    auto input = CTX_IN_MEM(const uint8_t *, i_arg);
    auto output = CTX_OUT_MEM(uint8_t *, o_arg);

    const auto conf = pd()->get_conf();

    const dim_t MB = conf.mb;
    const dim_t SP = conf.sp;
    const dim_t C = conf.c;
    const dim_t stride_mb = conf.stride_mb;
    const int data_type_size = conf.dt_size;

    if (pd()->get_conf().tag_kind == jit_memory_tag_kind_t::blocked) {
        // c_split_size == blk_size (set in pd_t::init), so CB is the
        // number of channel blocks; sp_split_size divides SP exactly.
        const dim_t CB = utils::div_up(C, conf.c_split_size);
        const dim_t SPB = SP / conf.sp_split_size;
        parallel_nd(MB, SPB, CB, [&](dim_t mb, dim_t spb, dim_t cb) {
            // Last channel block may be partial.
            const dim_t c_work
                    = nstl::min(conf.c_split_size, C - cb * conf.c_split_size);
            const dim_t c_curr = cb * conf.c_split_size;
            const dim_t sp_work = conf.sp_split_size;
            const dim_t sp_curr = spb * sp_work;
            // Element offset of this tile's spatial start within the
            // batch image (blocked layout: blk_size channels per point).
            const dim_t off = mb * stride_mb + sp_curr * conf.blk_size;

            jit_shuffle_call_s args;
            // src points at the image start: the kernel adds per-channel
            // input offsets itself. dst additionally skips SP * c_curr
            // elements, i.e. the cb-th channel block of the output.
            args.src = input + off * data_type_size;
            args.dst = output + (off + SP * c_curr) * data_type_size;

            args.cb_loop_size = c_work;
            // Tail handling for a possibly padded last channel block.
            args.is_padded_block = cb + 1 == CB;

            // Offsets for the channels this tile writes.
            args.input_off_ptr = this->input_off_ + c_curr;
            (*kernel_)(&args);
        });
    } else {
        assert(!"Invalid memory format kind.");
        return status::invalid_arguments;
    }

    return status::success;
}
220
// Explicit instantiations, one per dispatched ISA. Note that pd_t::init
// may still upgrade conf_.isa at runtime (avx -> avx2, and
// avx512_core -> avx512_core_bf16 for bf16 data).
template struct jit_uni_shuffle_t<sse41>;
template struct jit_uni_shuffle_t<avx>;
template struct jit_uni_shuffle_t<avx512_core>;
224
225} // namespace x64
226} // namespace cpu
227} // namespace impl
228} // namespace dnnl
229
230// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
231