1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "shuffle/shuffle.hpp" |
18 | #include "utils/parallel.hpp" |
19 | |
20 | namespace shuffle { |
21 | |
22 | void compute_ref( |
23 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
24 | const int src_arg = prb->dir == FWD_D ? DNNL_ARG_SRC : DNNL_ARG_DIFF_DST; |
25 | const int dst_arg = prb->dir == FWD_D ? DNNL_ARG_DST : DNNL_ARG_DIFF_SRC; |
26 | const dnn_mem_t &src = args.find(src_arg); |
27 | const dnn_mem_t &dst = args.find(dst_arg); |
28 | |
29 | float *dst_ptr = (float *)dst; |
30 | |
31 | const int axis = prb->axis; |
32 | const int64_t group_size = prb->group; |
33 | const int64_t axis_size = prb->dims[axis]; |
34 | int64_t inner_size = 1, outer_size = 1; |
35 | |
36 | auto transpose = [=](int64_t a) { |
37 | int64_t R, C; |
38 | if (prb->dir == FWD_D) { |
39 | R = group_size; |
40 | C = axis_size / group_size; |
41 | } else { |
42 | R = axis_size / group_size; |
43 | C = group_size; |
44 | } |
45 | int64_t col = a / R; |
46 | int64_t row = a % R; |
47 | return C * row + col; |
48 | }; |
49 | |
50 | for (int i = 0; i < axis; ++i) |
51 | outer_size *= (size_t)prb->dims[i]; |
52 | for (int i = axis + 1; i < prb->ndims; ++i) |
53 | inner_size *= (size_t)prb->dims[i]; |
54 | const size_t dim = axis_size * inner_size; |
55 | |
56 | benchdnn_parallel_nd(outer_size, axis_size, inner_size, |
57 | [&](int64_t ou, int64_t a, int64_t in) { |
58 | auto src_off = ou * dim + a * inner_size + in; |
59 | auto dst_off = ou * dim + transpose(a) * inner_size + in; |
60 | dst_ptr[dst_off] = src.get_elem(src_off); |
61 | }); |
62 | } |
63 | |
64 | } // namespace shuffle |
65 | |