1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "utils/parallel.hpp" |
18 | |
19 | #include "concat/concat.hpp" |
20 | |
21 | namespace concat { |
22 | |
23 | void get_sizes(const prb_t *prb, int64_t &outer_size, int64_t &inner_size, |
24 | int64_t &axis_size) { |
25 | outer_size = inner_size = 1; |
26 | for (int i = 0; i < prb->axis; i++) |
27 | outer_size *= prb->vdims[0][i]; |
28 | for (int i = prb->axis + 1; i < prb->ndims; i++) |
29 | inner_size *= prb->vdims[0][i]; |
30 | axis_size = prb->axis_size(); |
31 | } |
32 | |
33 | void compute_ref( |
34 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
35 | const dnn_mem_t &dst = args.find(DNNL_ARG_DST); |
36 | |
37 | float *dst_ptr = (float *)dst; |
38 | |
39 | int64_t outer_size {0}, inner_size {0}, axis_size {0}; |
40 | get_sizes(prb, outer_size, inner_size, axis_size); |
41 | |
42 | benchdnn_parallel_nd(outer_size, inner_size, [&](int64_t ou, int64_t in) { |
43 | int64_t off_dst = ou * axis_size * inner_size; |
44 | for (int i_input = 0; i_input < prb->n_inputs(); ++i_input) { |
45 | const dnn_mem_t &src_i = args.find(DNNL_ARG_MULTIPLE_SRC + i_input); |
46 | int64_t i_axis_size = prb->vdims[i_input][prb->axis]; |
47 | int64_t off_src = ou * i_axis_size * inner_size; |
48 | |
49 | float scale_i |
50 | = prb->attr.scales.get(DNNL_ARG_MULTIPLE_SRC + i_input) |
51 | .scale; |
52 | |
53 | for (int64_t as = 0; as < i_axis_size; ++as) { |
54 | int64_t idx = as * inner_size + in; |
55 | dst_ptr[off_dst + idx] |
56 | = src_i.get_elem(off_src + idx) * scale_i; |
57 | } |
58 | // the next input start point |
59 | off_dst += i_axis_size * inner_size; |
60 | } |
61 | }); |
62 | } |
63 | |
64 | } // namespace concat |
65 | |