1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
#include <vector>

#include "utils/parallel.hpp"

#include "concat/concat.hpp"
20
21namespace concat {
22
23void get_sizes(const prb_t *prb, int64_t &outer_size, int64_t &inner_size,
24 int64_t &axis_size) {
25 outer_size = inner_size = 1;
26 for (int i = 0; i < prb->axis; i++)
27 outer_size *= prb->vdims[0][i];
28 for (int i = prb->axis + 1; i < prb->ndims; i++)
29 inner_size *= prb->vdims[0][i];
30 axis_size = prb->axis_size();
31}
32
33void compute_ref(
34 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
35 const dnn_mem_t &dst = args.find(DNNL_ARG_DST);
36
37 float *dst_ptr = (float *)dst;
38
39 int64_t outer_size {0}, inner_size {0}, axis_size {0};
40 get_sizes(prb, outer_size, inner_size, axis_size);
41
42 benchdnn_parallel_nd(outer_size, inner_size, [&](int64_t ou, int64_t in) {
43 int64_t off_dst = ou * axis_size * inner_size;
44 for (int i_input = 0; i_input < prb->n_inputs(); ++i_input) {
45 const dnn_mem_t &src_i = args.find(DNNL_ARG_MULTIPLE_SRC + i_input);
46 int64_t i_axis_size = prb->vdims[i_input][prb->axis];
47 int64_t off_src = ou * i_axis_size * inner_size;
48
49 float scale_i
50 = prb->attr.scales.get(DNNL_ARG_MULTIPLE_SRC + i_input)
51 .scale;
52
53 for (int64_t as = 0; as < i_axis_size; ++as) {
54 int64_t idx = as * inner_size + in;
55 dst_ptr[off_dst + idx]
56 = src_i.get_elem(off_src + idx) * scale_i;
57 }
58 // the next input start point
59 off_dst += i_axis_size * inner_size;
60 }
61 });
62}
63
64} // namespace concat
65