1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cstring> |
18 | |
19 | #include "common/dnnl_thread.hpp" |
20 | |
21 | #include "cpu/simple_concat.hpp" |
22 | |
23 | namespace dnnl { |
24 | namespace impl { |
25 | namespace cpu { |
26 | |
27 | using namespace memory_tracking::names; |
28 | |
template <data_type_t data_type>
status_t simple_concat_t<data_type>::execute(const exec_ctx_t &ctx) const {
    // Concatenates `n_inputs()` source tensors into the destination along
    // `concat_dim()`. Two strategies are used:
    //   1) no outer loop needed (concat axis is effectively outermost):
    //      each input is one contiguous span in dst -> flat parallel copy;
    //   2) general case: a 5-D parallel loop over the physical (permuted,
    //      de-blocked) outer dimensions, copying one inner chunk per point.
    //
    // Scratchpad tables, one entry per input `a`:
    //   iptrs[a]          - source base pointer (nullptr for zero-size input)
    //   optrs[a]          - where input a's region starts inside dst
    //   nelems_to_copy[a] - elements of the contiguous inner chunk
    //   is[a]             - physical input strides for the outer loop
    auto scratchpad = ctx.get_scratchpad_grantor();
    auto iptrs = scratchpad.template get<const data_t *>(key_concat_iptrs);
    auto optrs = scratchpad.template get<data_t *>(key_concat_optrs);
    auto nelems_to_copy = scratchpad.template get<dim_t>(key_concat_nelems);
    auto is = scratchpad.template get<strides_t>(key_concat_istrides);

    const int num_arrs = pd()->n_inputs();
    // perm_/iperm_ map logical dims to/from the physical (stride-sorted)
    // order chosen at pd creation time; perm[concat_dim] is the number of
    // physical dims outer to the concat axis.
    const int *perm = pd()->perm_, *iperm = pd()->iperm_;
    const int concat_dim = pd()->concat_dim();
    auto o_base_ptr = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
    // A null dst buffer means a zero-volume destination: nothing to write.
    if (o_base_ptr == nullptr) return status::success;

    // Pass 1: fill the per-input tables. src_image_md(a) describes input a's
    // sub-region ("image") inside dst, so o_d.blk_off(0) lands optrs[a] on
    // the exact dst offset where input a begins.
    for (int a = 0; a < num_arrs; ++a) {
        const memory_desc_wrapper i_d(pd()->src_md(a));
        const memory_desc_wrapper o_d(pd()->src_image_md(a));
        const auto iptr = CTX_IN_MEM(const data_t *, DNNL_ARG_MULTIPLE_SRC + a);
        if (iptr == nullptr) {
            // Zero-size (or absent) input: mark it so both copy paths skip
            // it; its is[a]/optrs[a] entries stay unset but are never read.
            iptrs[a] = nullptr;
            nelems_to_copy[a] = 0;
            continue;
        }
        iptrs[a] = iptr + i_d.blk_off(0);
        optrs[a] = o_base_ptr + o_d.blk_off(0);
        nelems_to_copy[a] = pd()->nelems_to_concat(i_d);
        // Only strides for dims outer to the concat axis (in physical order)
        // drive the loop below; zero the rest so the 5-term offset formula
        // in the parallel_nd lambda never reads an indeterminate value.
        for (int i = 0; i < DNNL_MAX_NDIMS; i++) {
            if (i < perm[concat_dim])
                is[a][i] = size_t(i_d.blocking_desc().strides[iperm[i]]);
            else
                is[a][i] = 0;
        }
    }

    const memory_desc_wrapper o_d(pd()->dst_md(0));

    // Destination strides for the same outer physical dims, plus detection
    // of whether any outer dim actually has extent > 1.
    strides_t os = {0};
    bool has_outer_loop = false;
    for (int i = 0; i < perm[concat_dim]; i++) {
        os[i] = o_d.blocking_desc().strides[iperm[i]];
        // CAVEAT: if this impl supports not matching stag and dtag, strides
        // should be taken into account for this condition.
        if (o_d.padded_dims()[iperm[i]] != 1) has_outer_loop = true;
    }

    // Applies when concat axis is the outermost dimension, e.g. concat_axis = 0
    // or concat_axis = 1, and dims[0] = 1;
    if (!has_outer_loop) {
        // Fast path: each input occupies one contiguous range of dst, so we
        // split every input's range evenly across threads (balance211).
        int nthr = dnnl_get_max_threads();
        parallel(nthr, [&](int ithr, int nthr) {
            for (int a = 0; a < num_arrs; ++a) {
                dim_t start {0}, end {0};
                // Note: for a skipped input nelems_to_copy[a] == 0, so the
                // (possibly unset) iptrs[a]/optrs[a] are never dereferenced.
                balance211(nelems_to_copy[a], nthr, ithr, start, end);

                const data_t *i = iptrs[a] + start;
                data_t *o = optrs[a] + start;

                PRAGMA_OMP_SIMD()
                for (dim_t e = 0; e < end - start; ++e)
                    o[e] = i[e];
            }
        });

        return status::success;
    }

    // General case: iterate the outer physical dims. padded_dims / blocks_
    // converts logical extents to per-block iteration counts; unused dims
    // collapse to 1 so parallel_nd is always 5-D + input index.
    dims_t phys_dims;
    for (int i = 0; i < DNNL_MAX_NDIMS; i++) {
        if (i < perm[concat_dim])
            phys_dims[i]
                    = o_d.padded_dims()[iperm[i]] / pd()->blocks_[iperm[i]];
        else
            phys_dims[i] = 1;
    }

    const auto L1_size = platform::get_per_core_cache_size(1);
    UNUSED(L1_size); // for Windows

    parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3],
            phys_dims[4], num_arrs,
            [&](dim_t n0, dim_t n1, dim_t n2, dim_t n3, dim_t n4, dim_t a) {
                // check if zero memory
                if (iptrs[a] == nullptr) return;

                // XXX: this code may access uninitialized values in is[*][0-4] --
                // that's why we have to set them to zero although this is
                // probably benign
                size_t in_off = is[a][0] * n0 + is[a][1] * n1 + is[a][2] * n2
                        + is[a][3] * n3 + is[a][4] * n4;
                size_t out_off = os[0] * n0 + os[1] * n1 + os[2] * n2
                        + os[3] * n3 + os[4] * n4;
                const data_t *i = &iptrs[a][in_off];
                data_t *o = &optrs[a][out_off];

#if defined(__GNUC__)
                // Heuristic:
                // memcpy works generally faster for data sizes not
                // exceeding L1 cache.
                if (nelems_to_copy[a] * sizeof(data_t) > L1_size) {
                    // The code below performs data copying: o[e] = i[e]
                    // and uses a workaround to make GNU compilers optimize it
                    // (split into an unaligned byte "head", a 4-byte-word
                    // "main" body, and a byte "tail").
                    uint8_t *ptro = reinterpret_cast<uint8_t *>(o);
                    const uint8_t *ptri = reinterpret_cast<const uint8_t *>(i);

                    // head_part is 1..4 bytes (4 when `ptro` is already
                    // word-aligned -- main_part compensates by subtracting
                    // head_part/sizeof(data_t) elements, which relies on the
                    // dst base being at least sizeof(data_t)-aligned;
                    // presumably guaranteed by the engine's allocator --
                    // inherited from the original implementation.
                    const size_t head_part = sizeof(uint32_t)
                            - reinterpret_cast<uint64_t>(ptro)
                                    % sizeof(uint32_t);
                    const size_t main_part
                            = (nelems_to_copy[a] - head_part / sizeof(data_t))
                            * sizeof(data_t) / sizeof(uint32_t);
                    // Whatever bytes head + main did not cover.
                    const size_t tail_part
                            = (nelems_to_copy[a] * sizeof(data_t)) - head_part
                            - (main_part * sizeof(uint32_t));
                    for (size_t e = 0; e < head_part; ++e) {
                        *ptro = *ptri;
                        ++ptro;
                        ++ptri;
                    }
                    PRAGMA_OMP_SIMD()
                    for (size_t e = 0; e < main_part; ++e) {
                        *(reinterpret_cast<uint32_t *>(ptro))
                                = *(reinterpret_cast<const uint32_t *>(ptri));
                        ptro += sizeof(uint32_t);
                        ptri += sizeof(uint32_t);
                    }
                    for (size_t e = 0; e < tail_part; ++e) {
                        *ptro = *ptri;
                        ++ptro;
                        ++ptri;
                    }
                } else {
                    std::memcpy(o, i, nelems_to_copy[a] * sizeof(data_t));
                }
#else
                PRAGMA_OMP_SIMD()
                for (dim_t e = 0; e < nelems_to_copy[a]; ++e) o[e] = i[e];
#endif
            });

    return status::success;
}
170 | |
// Explicit instantiations: one concat implementation per supported data
// type. The copy kernels above are type-size-generic, so each instantiation
// differs only in sizeof(data_t).
template struct simple_concat_t<data_type::f32>;
template struct simple_concat_t<data_type::u8>;
template struct simple_concat_t<data_type::s8>;
template struct simple_concat_t<data_type::s32>;
template struct simple_concat_t<data_type::bf16>;
template struct simple_concat_t<data_type::f16>;
177 | |
178 | } // namespace cpu |
179 | } // namespace impl |
180 | } // namespace dnnl |
181 | |