/*******************************************************************************
* Copyright 2017-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cstring>

#include "common/dnnl_thread.hpp"

#include "cpu/simple_concat.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace memory_tracking::names;
template <data_type_t data_type>
status_t simple_concat_t<data_type>::execute(const exec_ctx_t &ctx) const {
    // Per-input bookkeeping prepared in the scratchpad: source pointers,
    // destination pointers, element counts, and input strides.
    auto scratchpad = ctx.get_scratchpad_grantor();
    auto iptrs = scratchpad.template get<const data_t *>(key_concat_iptrs);
    auto optrs = scratchpad.template get<data_t *>(key_concat_optrs);
    auto nelems_to_copy = scratchpad.template get<dim_t>(key_concat_nelems);
    auto is = scratchpad.template get<strides_t>(key_concat_istrides);

    const int num_arrs = pd()->n_inputs();
    const int *perm = pd()->perm_, *iperm = pd()->iperm_;
    const int concat_dim = pd()->concat_dim();
    auto o_base_ptr = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    // Nothing to write into (e.g. a zero-sized destination).
    if (o_base_ptr == nullptr) return status::success;

    for (int a = 0; a < num_arrs; ++a) {
        const memory_desc_wrapper i_d(pd()->src_md(a));
        const memory_desc_wrapper o_d(pd()->src_image_md(a));
        const auto iptr = CTX_IN_MEM(const data_t *, DNNL_ARG_MULTIPLE_SRC + a);
        // Inputs backed by zero memory contribute nothing; mark them so
        // the copy kernels below can skip them.
        if (iptr == nullptr) {
            iptrs[a] = nullptr;
            nelems_to_copy[a] = 0;
            continue;
        }
        iptrs[a] = iptr + i_d.blk_off(0);
        optrs[a] = o_base_ptr + o_d.blk_off(0);
        nelems_to_copy[a] = pd()->nelems_to_concat(i_d);
        // Keep strides only for the dimensions outer to the concat axis
        // (in permuted order); zero the rest so the offset computation in
        // the main loop can safely read all entries.
        for (int i = 0; i < DNNL_MAX_NDIMS; i++) {
            if (i < perm[concat_dim])
                is[a][i] = size_t(i_d.blocking_desc().strides[iperm[i]]);
            else
                is[a][i] = 0;
        }
    }
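
    // Worked example (illustrative): when concatenating two plain NCHW
    // tensors over the channel axis, perm and iperm are identity, so
    // perm[concat_dim] == 1 and only is[a][0] survives as the batch
    // stride C_a * H * W. nelems_to_copy[a] = C_a * H * W is the
    // contiguous span input a contributes per batch element, and optrs[a]
    // already carries input a's channel offset in dst via blk_off(0).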

    const memory_desc_wrapper o_d(pd()->dst_md(0));

    strides_t os = {0};
    bool has_outer_loop = false;
    // Collect output strides for the dimensions outer to the concat axis;
    // an outer loop is needed only if at least one of them is non-trivial.
    for (int i = 0; i < perm[concat_dim]; i++) {
        os[i] = o_d.blocking_desc().strides[iperm[i]];
        // CAVEAT: if this impl supports non-matching stag and dtag, strides
        // should be taken into account for this condition.
        if (o_d.padded_dims()[iperm[i]] != 1) has_outer_loop = true;
    }
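
    // Example (illustrative): for plain NCHW inputs concatenated over C,
    // the only dimension outer to the concat axis is N, so has_outer_loop
    // is true iff the (padded) batch size exceeds 1.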

    // Applies when the concat axis is the outermost dimension, e.g.
    // concat_axis = 0, or concat_axis = 1 with dims[0] == 1. Each input is
    // then a single contiguous block in the destination, so the copies can
    // simply be split evenly across threads.
    if (!has_outer_loop) {
        int nthr = dnnl_get_max_threads();
        parallel(nthr, [&](int ithr, int nthr) {
            for (int a = 0; a < num_arrs; ++a) {
                // balance211 assigns this thread its [start, end) share of
                // input a's elements.
                dim_t start {0}, end {0};
                balance211(nelems_to_copy[a], nthr, ithr, start, end);

                const data_t *i = iptrs[a] + start;
                data_t *o = optrs[a] + start;

                PRAGMA_OMP_SIMD()
                for (dim_t e = 0; e < end - start; ++e)
                    o[e] = i[e];
            }
        });

        return status::success;
    }
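
    // Worked example for the fast path above (illustrative): concatenating
    // 1D tensors of 1000 and 3000 elements over axis 0 with 4 threads,
    // each input is one contiguous block in dst, and balance211 hands
    // every thread a 250-element slice of input 0 and a 750-element slice
    // of input 1.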

    // General path: iterate over all physical dimensions outer to the
    // concat axis (counted in blocks for blocked formats); dims at or
    // inside the concat axis collapse to 1, since each iteration copies
    // them as one contiguous span.
    dims_t phys_dims;
    for (int i = 0; i < DNNL_MAX_NDIMS; i++) {
        if (i < perm[concat_dim])
            phys_dims[i]
                    = o_d.padded_dims()[iperm[i]] / pd()->blocks_[iperm[i]];
        else
            phys_dims[i] = 1;
    }
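
    // Worked example (illustrative): for plain NCHW inputs concatenated
    // over H, perm is identity and blocks_ is all ones, so phys_dims
    // = {N, C, 1, 1, ...}: the loop below visits every (n, c) pair and
    // each iteration copies one contiguous span of H_a * W elements.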

    const auto L1_size = platform::get_per_core_cache_size(1);
    UNUSED(L1_size); // unreferenced on non-GNU builds (e.g. Windows)

    parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3],
            phys_dims[4], num_arrs,
            [&](dim_t n0, dim_t n1, dim_t n2, dim_t n3, dim_t n4, dim_t a) {
                // Skip inputs backed by zero memory.
                if (iptrs[a] == nullptr) return;

                // XXX: this code reads is[*][0-4] even for the dimensions
                // at or beyond perm[concat_dim]; those entries are
                // explicitly zeroed above, so the reads are benign.
                size_t in_off = is[a][0] * n0 + is[a][1] * n1 + is[a][2] * n2
                        + is[a][3] * n3 + is[a][4] * n4;
                size_t out_off = os[0] * n0 + os[1] * n1 + os[2] * n2
                        + os[3] * n3 + os[4] * n4;
                const data_t *i = &iptrs[a][in_off];
                data_t *o = &optrs[a][out_off];
#if defined(__GNUC__)
                // Heuristic: memcpy is generally faster for data sizes
                // that do not exceed the L1 cache.
                if (nelems_to_copy[a] * sizeof(data_t) > L1_size) {
                    // The code below performs the copy o[e] = i[e] via a
                    // head/main/tail split, a workaround that lets GNU
                    // compilers vectorize the main part.
                    uint8_t *ptro = reinterpret_cast<uint8_t *>(o);
                    const uint8_t *ptri = reinterpret_cast<const uint8_t *>(i);

                    // Head: copy byte-by-byte until ptro is uint32_t-aligned.
                    // Note head_part is sizeof(uint32_t), not 0, when ptro
                    // is already aligned, which is still correct.
                    const size_t head_part = sizeof(uint32_t)
                            - reinterpret_cast<uint64_t>(ptro)
                                    % sizeof(uint32_t);
                    // Main: whole uint32_t words; tail: remaining bytes.
                    const size_t main_part
                            = (nelems_to_copy[a] - head_part / sizeof(data_t))
                            * sizeof(data_t) / sizeof(uint32_t);
                    const size_t tail_part
                            = (nelems_to_copy[a] * sizeof(data_t)) - head_part
                            - (main_part * sizeof(uint32_t));
                    for (size_t e = 0; e < head_part; ++e) {
                        *ptro = *ptri;
                        ++ptro;
                        ++ptri;
                    }
                    PRAGMA_OMP_SIMD()
                    for (size_t e = 0; e < main_part; ++e) {
                        *(reinterpret_cast<uint32_t *>(ptro))
                                = *(reinterpret_cast<const uint32_t *>(ptri));
                        ptro += sizeof(uint32_t);
                        ptri += sizeof(uint32_t);
                    }
                    for (size_t e = 0; e < tail_part; ++e) {
                        *ptro = *ptri;
                        ++ptro;
                        ++ptri;
                    }
                } else {
                    std::memcpy(o, i, nelems_to_copy[a] * sizeof(data_t));
                }
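                // Worked example for the split above (illustrative,
                // data_t = u8): copying 1031 bytes to a destination
                // address with addr % 4 == 3 gives head_part = 1,
                // main_part = 257 words (1028 bytes) and tail_part = 2,
                // i.e. 1031 bytes in total.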
#else
                PRAGMA_OMP_SIMD()
                for (dim_t e = 0; e < nelems_to_copy[a]; ++e) o[e] = i[e];
#endif
            });

    return status::success;
}

template struct simple_concat_t<data_type::f32>;
template struct simple_concat_t<data_type::u8>;
template struct simple_concat_t<data_type::s8>;
template struct simple_concat_t<data_type::s32>;
template struct simple_concat_t<data_type::bf16>;
template struct simple_concat_t<data_type::f16>;
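
// Usage sketch (illustrative, not part of the implementation): this impl is
// reached through the generic concat primitive. Assuming the oneDNN v2.x
// C++ API, a caller would do roughly:
//
//   std::vector<dnnl::memory::desc> srcs = {src0_md, src1_md};
//   auto cpd = dnnl::concat::primitive_desc(
//           /* concat_dimension = */ 1, srcs, eng);
//   dnnl::concat(cpd).execute(strm,
//           {{DNNL_ARG_MULTIPLE_SRC + 0, src0_mem},
//            {DNNL_ARG_MULTIPLE_SRC + 1, src1_mem},
//            {DNNL_ARG_DST, dst_mem}});
//
// src0_md, src1_md, eng, strm and the memory objects are placeholders.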

} // namespace cpu
} // namespace impl
} // namespace dnnl