1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | #include <math.h> |
19 | |
20 | #include "common/c_types_map.hpp" |
21 | #include "common/dnnl_thread.hpp" |
22 | #include "common/type_helpers.hpp" |
23 | |
24 | #include "cpu/ref_shuffle.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace cpu { |
29 | |
30 | using namespace format_tag; |
31 | |
// Reference (non-JIT) shuffle implementation. The template parameter is the
// element size in bytes rather than the data type: a shuffle only relocates
// values, so every data type of a given size can share one instantiation.
//
// Forward reads DNNL_ARG_SRC and writes DNNL_ARG_DST; backward reads
// DNNL_ARG_DIFF_DST and writes DNNL_ARG_DIFF_SRC. Both directions gather
// through rev_transposed_, which maps each destination index along the
// shuffle axis to the source index it is taken from (presumably precomputed
// at primitive creation time -- not visible in this file; confirm in init()).
template <int data_type_size>
status_t ref_shuffle_t::execute_(const exec_ctx_t &ctx) const {
    using namespace prop_kind;
    using namespace utils;
    // Map the byte size back to a concrete POD type so that pointer
    // arithmetic below moves in element-sized steps.
    using data_t = typename typesize_traits<data_type_size>::type;

    // Forward and backward share the same gather kernel; only the memory
    // descriptor and the in/out execution arguments differ.
    const memory_desc_wrapper src_d(
            pd()->is_fwd() ? pd()->src_md() : pd()->diff_src_md());

    status_t status = status::success;
    auto i_arg = pd()->is_fwd() ? DNNL_ARG_SRC : DNNL_ARG_DIFF_DST;
    auto o_arg = pd()->is_fwd() ? DNNL_ARG_DST : DNNL_ARG_DIFF_SRC;
    auto input = CTX_IN_MEM(const data_t *, i_arg);
    auto output = CTX_OUT_CLEAN_MEM(data_t *, o_arg, status);
    CHECK(status);

    const int axis = pd()->axis(); // dimension being shuffled
    const int axis_size = pd()->axis_size(); // used by the generic fallback

    const dim_t MB = pd()->MB();
    const dim_t C = pd()->C();
    // Spatial sizes; they stay 1 when the tensor has no spatial dims
    // (ndims outside 3..5).
    dim_t H = 1, W = 1, D = 1, HW = 1, SP = 1;
    const bool has_spatial = utils::one_of(src_d.ndims(), 3, 4, 5);
    if (has_spatial) {
        D = pd()->D();
        H = pd()->H();
        W = pd()->W();
        HW = H * W;
        SP = D * HW; // total number of spatial points per (mb, c)
    }
    const dim_t stride_mb = src_d.blocking_desc().strides[0];
    // Stride of the innermost logical dimension. For the blocked nC..c
    // layouts handled in the first branch this equals the channel block
    // size (e.g. 16 for nChw16c).
    const dim_t blksize = src_d.blocking_desc().strides[pd()->ndims() - 1];
    const format_tag_t tag = pd()->dat_tag_;

    if (axis == 1
            && one_of(
                    tag, nChw16c, nChw8c, nChw4c, nCdhw16c, nCdhw8c, nCdhw4c)) {
        // Channel shuffle on blocked layouts: walk output channel blocks and
        // gather each output channel from its (block, within-block) position
        // in the input.
        //
        // Under the OpenMP runtime a raw collapse(3) loop nest is used;
        // other runtimes fall back to the equivalent parallel_nd below.
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
#pragma omp parallel for collapse(3) schedule(static)
        for_(dim_t mb = 0; mb < MB; ++mb)
        for_(dim_t cb = 0; cb < C; cb += blksize)
        for (dim_t sp = 0; sp < SP; ++sp) {
            // Offset common to input and output: batch base + spatial point
            // within a channel block.
            const dim_t off = mb * stride_mb + sp * blksize;
            const dim_t output_off = off + cb * SP;
            // nstl::min clips the last block when C % blksize != 0.
            PRAGMA_OMP_SIMD()
            for (dim_t cc = 0; cc < nstl::min(blksize, C - cb); ++cc) {
                const dim_t input_c = rev_transposed_[cb + cc];
                // Decompose the source channel into block index and
                // within-block offset to address the blocked layout.
                const dim_t input_off = off + input_c / blksize * SP * blksize
                        + input_c % blksize;
                output[output_off + cc] = input[input_off];
            }
        }
#else
        parallel_nd(MB, utils::div_up(C, blksize), SP,
                [&](dim_t mb, dim_t c, dim_t sp) {
                    const dim_t off = mb * stride_mb + sp * blksize;
                    const dim_t cb = c * blksize;
                    const dim_t output_off = off + cb * SP;
                    // nstl::min clips the last block when C % blksize != 0.
                    PRAGMA_OMP_SIMD()
                    for (dim_t cc = 0; cc < nstl::min(blksize, C - cb); ++cc) {
                        const dim_t input_c = rev_transposed_[cb + cc];
                        const dim_t input_off = off
                                + input_c / blksize * SP * blksize
                                + input_c % blksize;
                        output[output_off + cc] = input[input_off];
                    }
                });
#endif
    } else if (axis == 1 && one_of(tag, nhwc, ndhwc)) {
        // Channels-last: the C channels at each spatial point are
        // contiguous, so the shuffle is a dense C-length permutation.
        parallel_nd(MB, SP, [&](dim_t mb, dim_t sp) {
            const dim_t off = mb * stride_mb + sp * C;
            PRAGMA_OMP_SIMD()
            for (dim_t c = 0; c < C; ++c)
                output[off + c] = input[off + rev_transposed_[c]];
        });
    } else if (axis == 1 && one_of(tag, nchw, ncdhw)) {
        // Channels-first: each channel is a contiguous SP-length plane,
        // so whole planes are copied according to the permutation.
        parallel_nd(MB, C, [&](dim_t mb, dim_t c) {
            const dim_t output_off = mb * stride_mb + c * SP;
            const dim_t input_off = mb * stride_mb + rev_transposed_[c] * SP;
            PRAGMA_OMP_SIMD()
            for (dim_t sp = 0; sp < SP; ++sp) {
                output[output_off + sp] = input[input_off + sp];
            }
        });
    } else {
        // Generic fallback for any axis/layout: view the tensor as
        // outer x axis x inner in logical order and translate each logical
        // offset to its physical one via src_d.off_l().
        auto dims = pd()->desc()->src_desc.dims;
        auto ndims = pd()->ndims();
        const dim_t outer_size = utils::array_product(dims, axis);
        const dim_t inner_size
                = utils::array_product(dims + axis + 1, ndims - axis - 1);
        const dim_t dim = axis_size * inner_size; // logical size of one outer slice

        parallel_nd(outer_size, axis_size, inner_size,
                [&](dim_t ou, dim_t a, dim_t in) {
                    const dim_t off = ou * dim + in;
                    auto &o = output[src_d.off_l(off + a * inner_size)];
                    o = input[src_d.off_l(
                            off + rev_transposed_[a] * inner_size)];
                });
    }
    return status::success;
}
134 | |
// Explicit instantiations, one per element byte size (4, 2, 1). Since the
// kernel is parameterized on size only, each instantiation presumably also
// serves other data types of the same size (e.g. the 4-byte one for s32 as
// well as f32) -- dispatch by size would happen in the caller; confirm there.
template status_t ref_shuffle_t::execute_<sizeof(float)>(
        const exec_ctx_t &ctx) const;
template status_t ref_shuffle_t::execute_<sizeof(bfloat16_t)>(
        const exec_ctx_t &ctx) const;
template status_t ref_shuffle_t::execute_<sizeof(int8_t)>(
        const exec_ctx_t &ctx) const;
141 | |
142 | } // namespace cpu |
143 | } // namespace impl |
144 | } // namespace dnnl |
145 | |
146 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
147 | |