1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "shuffle/shuffle.hpp"
18#include "utils/parallel.hpp"
19
20namespace shuffle {
21
22void compute_ref(
23 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
24 const int src_arg = prb->dir == FWD_D ? DNNL_ARG_SRC : DNNL_ARG_DIFF_DST;
25 const int dst_arg = prb->dir == FWD_D ? DNNL_ARG_DST : DNNL_ARG_DIFF_SRC;
26 const dnn_mem_t &src = args.find(src_arg);
27 const dnn_mem_t &dst = args.find(dst_arg);
28
29 float *dst_ptr = (float *)dst;
30
31 const int axis = prb->axis;
32 const int64_t group_size = prb->group;
33 const int64_t axis_size = prb->dims[axis];
34 int64_t inner_size = 1, outer_size = 1;
35
36 auto transpose = [=](int64_t a) {
37 int64_t R, C;
38 if (prb->dir == FWD_D) {
39 R = group_size;
40 C = axis_size / group_size;
41 } else {
42 R = axis_size / group_size;
43 C = group_size;
44 }
45 int64_t col = a / R;
46 int64_t row = a % R;
47 return C * row + col;
48 };
49
50 for (int i = 0; i < axis; ++i)
51 outer_size *= (size_t)prb->dims[i];
52 for (int i = axis + 1; i < prb->ndims; ++i)
53 inner_size *= (size_t)prb->dims[i];
54 const size_t dim = axis_size * inner_size;
55
56 benchdnn_parallel_nd(outer_size, axis_size, inner_size,
57 [&](int64_t ou, int64_t a, int64_t in) {
58 auto src_off = ou * dim + a * inner_size + in;
59 auto dst_off = ou * dim + transpose(a) * inner_size + in;
60 dst_ptr[dst_off] = src.get_elem(src_off);
61 });
62}
63
64} // namespace shuffle
65