1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include <math.h>
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/type_helpers.hpp"
23
24#include "cpu/ref_shuffle.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace cpu {
29
30using namespace format_tag;
31
template <int data_type_size>
status_t ref_shuffle_t::execute_(const exec_ctx_t &ctx) const {
    using namespace prop_kind;
    using namespace utils;
    // Shuffle only moves elements, never reinterprets them, so a single
    // instantiation per element size covers all data types of that width
    // (e.g. 4 bytes: f32/s32, 2 bytes: bf16, 1 byte: s8/u8).
    using data_t = typename typesize_traits<data_type_size>::type;

    // Forward pass reads src; backward pass reads diff_dst. Both share the
    // same layout, so one descriptor suffices for all offset computations.
    const memory_desc_wrapper src_d(
            pd()->is_fwd() ? pd()->src_md() : pd()->diff_src_md());

    status_t status = status::success;
    // Pick argument slots by propagation direction: fwd shuffles src -> dst,
    // bwd "un-shuffles" diff_dst -> diff_src using the same kernel.
    auto i_arg = pd()->is_fwd() ? DNNL_ARG_SRC : DNNL_ARG_DIFF_DST;
    auto o_arg = pd()->is_fwd() ? DNNL_ARG_DST : DNNL_ARG_DIFF_SRC;
    auto input = CTX_IN_MEM(const data_t *, i_arg);
    auto output = CTX_OUT_CLEAN_MEM(data_t *, o_arg, status);
    CHECK(status);

    const int axis = pd()->axis(); // dimension being shuffled
    const int axis_size = pd()->axis_size();

    const dim_t MB = pd()->MB();
    const dim_t C = pd()->C();
    // Spatial extents collapse to 1 for 0-D/1-D spatial cases so the same
    // indexing formulas work for 3D, 4D and 5D tensors.
    dim_t H = 1, W = 1, D = 1, HW = 1, SP = 1;
    const bool has_spatial = utils::one_of(src_d.ndims(), 3, 4, 5);
    if (has_spatial) {
        D = pd()->D();
        H = pd()->H();
        W = pd()->W();
        HW = H * W;
        SP = D * HW; // total spatial size
    }
    const dim_t stride_mb = src_d.blocking_desc().strides[0];
    // For the blocked formats handled below (nChw16c etc.) the stride of the
    // innermost spatial dim equals the channel block size.
    const dim_t blksize = src_d.blocking_desc().strides[pd()->ndims() - 1];
    const format_tag_t tag = pd()->dat_tag_;

    // rev_transposed_[c] holds the source channel feeding destination channel
    // c — presumably the inverse permutation precomputed at pd creation
    // (defined in the header, not visible here).
    if (axis == 1
            && one_of(
                    tag, nChw16c, nChw8c, nChw4c, nCdhw16c, nCdhw8c, nCdhw4c)) {
// Blocked-channel layouts: channels are stored in groups of `blksize`, with
// spatial points strided by `blksize` inside each group.
#if DNNL_CPU_THREADING_RUNTIME == DNNL_RUNTIME_OMP
// Raw OpenMP loop: collapse(3) over mb/channel-block/spatial gives more
// parallelism than nesting would (the `for_` macros keep the loops in the
// canonical form collapse requires).
#pragma omp parallel for collapse(3) schedule(static)
    for_(dim_t mb = 0; mb < MB; ++mb)
    for_(dim_t cb = 0; cb < C; cb += blksize)
    for (dim_t sp = 0; sp < SP; ++sp) {
        // Offset of this spatial point relative to a channel-block base.
        const dim_t off = mb * stride_mb + sp * blksize;
        const dim_t output_off = off + cb * SP;
        PRAGMA_OMP_SIMD()
        // min(blksize, C - cb) handles a partially filled tail block.
        for (dim_t cc = 0; cc < nstl::min(blksize, C - cb); ++cc) {
            const dim_t input_c = rev_transposed_[cb + cc];
            // Decompose source channel into (block index, in-block index).
            const dim_t input_off = off + input_c / blksize * SP * blksize
                    + input_c % blksize;
            output[output_off + cc] = input[input_off];
        }
    }
#else
    // Non-OMP runtimes: same traversal via the generic parallel_nd driver,
    // iterating channel-block indices (c = cb / blksize) directly.
    parallel_nd(MB, utils::div_up(C, blksize), SP,
            [&](dim_t mb, dim_t c, dim_t sp) {
                const dim_t off = mb * stride_mb + sp * blksize;
                const dim_t cb = c * blksize;
                const dim_t output_off = off + cb * SP;
                PRAGMA_OMP_SIMD()
                for (dim_t cc = 0; cc < nstl::min(blksize, C - cb); ++cc) {
                    const dim_t input_c = rev_transposed_[cb + cc];
                    const dim_t input_off = off
                            + input_c / blksize * SP * blksize
                            + input_c % blksize;
                    output[output_off + cc] = input[input_off];
                }
            });
#endif
    } else if (axis == 1 && one_of(tag, nhwc, ndhwc)) {
        // Channels-last: all C channels of one spatial point are contiguous,
        // so the permutation is a dense gather within each point.
        parallel_nd(MB, SP, [&](dim_t mb, dim_t sp) {
            const dim_t off = mb * stride_mb + sp * C;
            PRAGMA_OMP_SIMD()
            for (dim_t c = 0; c < C; ++c)
                output[off + c] = input[off + rev_transposed_[c]];
        });
    } else if (axis == 1 && one_of(tag, nchw, ncdhw)) {
        // Channels-first plain layout: each channel is a contiguous SP-sized
        // plane, so shuffling is a per-channel memcpy-like copy.
        parallel_nd(MB, C, [&](dim_t mb, dim_t c) {
            const dim_t output_off = mb * stride_mb + c * SP;
            const dim_t input_off = mb * stride_mb + rev_transposed_[c] * SP;
            PRAGMA_OMP_SIMD()
            for (dim_t sp = 0; sp < SP; ++sp) {
                output[output_off + sp] = input[input_off + sp];
            }
        });
    } else {
        // Generic fallback for any axis/format: compute logical offsets and
        // let off_l() translate them to physical ones (slower, but general).
        auto dims = pd()->desc()->src_desc.dims;
        auto ndims = pd()->ndims();
        // outer_size: product of dims before the axis; inner_size: after it.
        const dim_t outer_size = utils::array_product(dims, axis);
        const dim_t inner_size
                = utils::array_product(dims + axis + 1, ndims - axis - 1);
        const dim_t dim = axis_size * inner_size;

        parallel_nd(outer_size, axis_size, inner_size,
                [&](dim_t ou, dim_t a, dim_t in) {
                    const dim_t off = ou * dim + in;
                    auto &o = output[src_d.off_l(off + a * inner_size)];
                    o = input[src_d.off_l(
                            off + rev_transposed_[a] * inner_size)];
                });
    }
    return status::success;
}
134
// Explicit instantiations: one per supported element size (4, 2 and 1 bytes).
// The kernel is type-agnostic, so sizeof() is the only template parameter.
template status_t ref_shuffle_t::execute_<sizeof(float)>(
        const exec_ctx_t &ctx) const;
template status_t ref_shuffle_t::execute_<sizeof(bfloat16_t)>(
        const exec_ctx_t &ctx) const;
template status_t ref_shuffle_t::execute_<sizeof(int8_t)>(
        const exec_ctx_t &ctx) const;
141
142} // namespace cpu
143} // namespace impl
144} // namespace dnnl
145
146// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
147