/*******************************************************************************
* Copyright 2018-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <cassert>

#include "dnnl_thread.hpp"
#include "dnnl_traits.hpp"
#include "stream.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "memory.hpp"
#include "primitive_exec_types.hpp"

using namespace dnnl::impl;
using namespace dnnl::impl::data_type;
using namespace dnnl::impl::status;

enum blk_kind_t { a, b, c, ab, ba, bc, cb };
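// blk_kind_t records which logical dimensions are blocked and their nesting
// order: a single letter (a, b, c) means only that dimension is blocked;
// a pair (e.g. ab) means both are blocked, with the first letter the outer
// and the second the inner blocking level.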

template <data_type_t dt, blk_kind_t blk_kind, int blksize>
void typed_zero_pad_blk(const memory_desc_wrapper &m_d, void *data_handle) {
    /* Note: for bf16 memory, use uint16_t to initialize the padding to
     * zero, in order to avoid using the assignment operators defined in
     * bfloat16_t. This allows the user to create bf16 memory on
     * non-avx512_core machines. */
    using data_t = typename utils::conditional<dt == bf16, uint16_t,
            typename prec_traits<dt>::type>::type;
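    // For example, data_t resolves to uint16_t when dt == bf16 (the padding
    // is then written as the plain 16-bit pattern 0x0000), and to the natural
    // C++ type otherwise: float for f32, int32_t for s32, int8_t for s8.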
    auto data = reinterpret_cast<data_t *>(data_handle);
    const auto &dims = m_d.dims();
    const auto &pdims = m_d.padded_dims();
    const auto &blk = m_d.blocking_desc();
    auto dim_is_blocked = [&](int dim) {
        for (int i = 0; i < blk.inner_nblks; i++)
            if (blk.inner_idxs[i] == dim) return true;
        return false;
    };
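    // E.g., for a descriptor in the aBcd16b layout (inner_nblks == 1,
    // inner_idxs == {1}), dim_is_blocked(1) is true and dim_is_blocked(0)
    // is false.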
    bool A_blocked = dim_is_blocked(0), B_blocked = dim_is_blocked(1),
            C_blocked = dim_is_blocked(2);

    assert(blk.inner_nblks < 4);
    assert(A_blocked || B_blocked || C_blocked);

    const int a_tail_s = A_blocked ? dims[0] % blksize : 0;
    const int b_tail_s = B_blocked ? dims[1] % blksize : 0;
    const int c_tail_s = C_blocked ? dims[2] % blksize : 0;
    assert(a_tail_s || b_tail_s || c_tail_s);

    const int ndims = m_d.ndims();
    assert(1 <= ndims && ndims <= 6);
    const dim_t A = A_blocked ? pdims[0] / blksize : dims[0];
    const dim_t B = ndims <= 1 ? 1 : B_blocked ? pdims[1] / blksize : dims[1];
    const dim_t C = ndims <= 2 ? 1 : C_blocked ? pdims[2] / blksize : dims[2];
    const dim_t D = ndims <= 3 ? 1 : dims[3];
    const dim_t E = ndims <= 4 ? 1 : dims[4];
    const dim_t F = ndims <= 5 ? 1 : dims[5];
    const dim_t inner_blk = blk.inner_nblks == 3 ? blk.inner_blks[2] : 1;
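    // Worked example: dims = {2, 17, 5, 5} with dim 1 blocked by 16 gives
    // pdims = {2, 32, 5, 5}, b_tail_s = 17 % 16 = 1, and B = 32 / 16 = 2;
    // only the last of the two blocks along dim 1 has a partially filled
    // tail that needs zeroing.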

    auto zeroize_tail = [&](data_t *d, const int tail_s) {
        for (int b = tail_s; b < blksize; ++b)
            d[b] = 0;
    };
    auto zeroize_tail_inner = [&](data_t *d, const int tail_s) {
        for (int b1 = 0; b1 < blksize; ++b1)
            for (int b2 = tail_s; b2 < blksize; ++b2)
                d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2
                        + b1 % inner_blk]
                        = 0;
    };
    auto zeroize_tail_outer = [&](data_t *d, const int tail_s) {
        for (int b1 = tail_s; b1 < blksize; ++b1)
            for (int b2 = 0; b2 < blksize; ++b2)
                d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2
                        + b1 % inner_blk]
                        = 0;
    };
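    // The shared index expression walks a doubly-blocked tile. For instance,
    // with blksize = 16 and inner_blk = 2 (a three-level blocking), element
    // (b1, b2) = (3, 5) lands at offset
    // (3 / 2) * 16 * 2 + 2 * 5 + 3 % 2 = 32 + 10 + 1 = 43.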

    if (c_tail_s) {
        parallel_nd(A, B, D, E, F,
                [&](dim_t a, dim_t b, dim_t d, dim_t e, dim_t f) {
                    auto x = &data[m_d.blk_off(a, b, C - 1, d, e, f)];
                    if (blk_kind == c)
                        zeroize_tail(x, c_tail_s);
                    else if (blk_kind == bc)
                        zeroize_tail_inner(x, c_tail_s);
                    else if (blk_kind == cb)
                        zeroize_tail_outer(x, c_tail_s);
                });
    }
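    // Each of these passes touches only the last block along its blocked
    // dimension (index C - 1, B - 1, or A - 1), since all preceding blocks
    // are fully populated. For instance, with dims[2] = 20 and blksize = 16,
    // c_tail_s = 4 and elements [4, 16) of the final block get zeroed.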

    if (b_tail_s) {
        parallel_nd(A, C, D, E, F,
                [&](dim_t a, dim_t c, dim_t d, dim_t e, dim_t f) {
                    auto x = &data[m_d.blk_off(a, B - 1, c, d, e, f)];
                    if (blk_kind == b)
                        zeroize_tail(x, b_tail_s);
                    else if (blk_kind == ab || blk_kind == cb)
                        zeroize_tail_inner(x, b_tail_s);
                    else if (blk_kind == ba || blk_kind == bc)
                        zeroize_tail_outer(x, b_tail_s);
                });
    }

    if (a_tail_s) {
        parallel_nd(B, C, D, E, F,
                [&](dim_t b, dim_t c, dim_t d, dim_t e, dim_t f) {
                    auto x = &data[m_d.blk_off(A - 1, b, c, d, e, f)];
                    if (blk_kind == a)
                        zeroize_tail(x, a_tail_s);
                    else if (blk_kind == ba)
                        zeroize_tail_inner(x, a_tail_s);
                    else if (blk_kind == ab)
                        zeroize_tail_outer(x, a_tail_s);
                });
    }
}

/* Fallback that handles all blocked formats, one logical element at a time. */
template <data_type_t dt>
void typed_zero_pad_generic_blocked(
        const memory_desc_wrapper &m_d, void *data_handle) {
    /* Note: for bf16 memory, use uint16_t to initialize the padding to
     * zero, in order to avoid using the assignment operators defined in
     * bfloat16_t. This allows the user to create bf16 memory on
     * non-avx512_core machines. */
    using data_t = typename utils::conditional<dt == bf16, uint16_t,
            typename prec_traits<dt>::type>::type;
    auto data = reinterpret_cast<data_t *>(data_handle);
    const int ndims = m_d.ndims();
    const auto &dims = m_d.dims();
    const auto &pdims = m_d.padded_dims();

    const ptrdiff_t nelems = (ptrdiff_t)m_d.nelems(true);

    /* [D_0] .. [D_k] [D_k+1] .. [D_ndims-1]
     *            |    \                   /
     *            |     -------------------
     *           has        contiguous
     *         padding      (no padding)
     *
     * step     <-- D_k+1 * ... * D_ndims-1
     * step_dim <-- k
     */

    ptrdiff_t step = 1;
    int step_dim = ndims - 1;
    for (; step_dim >= 0; --step_dim) {
        if (dims[step_dim] != pdims[step_dim]) break;
        step *= dims[step_dim];
    }
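    // Worked example: dims = {2, 17, 5, 5} and pdims = {2, 32, 5, 5}. The
    // scan from the last dimension stops at dim 1 (17 != 32), so step_dim = 1
    // and step = 5 * 5 = 25; all 25 trailing elements under a fixed (d0, d1)
    // prefix share the same padding verdict, so the check below runs once per
    // prefix.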

    assert(step_dim >= 0 && "no zero padding is required");
    if (step_dim < 0) return;

    parallel_nd(nelems / step, [&](ptrdiff_t e1) {
        bool need_zero = false;

        ptrdiff_t idx = e1;
        for (int d = step_dim; d >= 0; --d) {
            if (idx % pdims[d] >= dims[d]) {
                need_zero = true;
                break;
            }
            idx /= pdims[d];
        }

        if (need_zero) {
            for (ptrdiff_t e0 = 0; e0 < step; ++e0)
                data[m_d.off_l(e1 * step + e0, true)] = 0;
        }
    });
}

template <data_type_t dt>
status_t typed_zero_pad(const memory_t *memory, const exec_ctx_t &ctx) {
    const memory_desc_wrapper mdw(memory->md());
    memory_storage_t *memory_storage = memory->memory_storage();

    if (mdw.format_kind() != format_kind::blocked) return unimplemented;

    if (mdw.nelems(false) == mdw.nelems(true)) return success;

    const size_t map_size = mdw.size();
    assert(map_size != DNNL_RUNTIME_SIZE_VAL);

    void *mapped_ptr
            = ctx.map_memory_storage(memory_storage, ctx.stream(), map_size);

    auto *data = static_cast<typename prec_traits<dt>::type *>(mapped_ptr);
    auto blk = mdw.blocking_desc();

    auto get_blksize = [&](int ind) {
        int blksize = 1;
        for (int i = 0; i < blk.inner_nblks; i++) {
            if (blk.inner_idxs[i] == ind) blksize *= blk.inner_blks[i];
        }
        return blksize;
    };
    const int blksize = get_blksize(blk.inner_idxs[0]);
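    // get_blksize accumulates the combined block size of one logical
    // dimension across all blocking levels: for a single-level 16b blocking,
    // get_blksize(1) == 16; for a hypothetical two-level 16a2a split along
    // dim 0, get_blksize(0) == 16 * 2 == 32.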

#define CASE(blksize_, blk_kind) \
    do { \
        if (blksize == (blksize_)) { \
            typed_zero_pad_blk<dt, blk_kind, blksize_>(mdw, data); \
            ctx.unmap_memory_storage( \
                    memory_storage, mapped_ptr, ctx.stream()); \
            return success; \
        } \
    } while (0)
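// For instance, CASE(16, ab) expands to a guarded call of
// typed_zero_pad_blk<dt, ab, 16>(mdw, data) that unmaps the storage and
// returns success when blksize == 16; sizes with no matching CASE fall
// through to the generic path below.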

    switch (blk.inner_nblks) {
        case 1:
            if (blk.inner_idxs[0] == 0) {
                CASE(4, a);
                CASE(8, a);
                CASE(16, a);
            } else if (blk.inner_idxs[0] == 1) {
                CASE(4, b);
                CASE(8, b);
                CASE(16, b);
            }
            break;
        case 2:
        case 3:
            if (blk.inner_nblks == 3 && blk.inner_idxs[0] != blk.inner_idxs[2])
                break;
            if (blksize != get_blksize(blk.inner_idxs[1])) break;

            if (blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) {
                CASE(4, ab);
                CASE(8, ab);
                CASE(16, ab);
            } else if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0) {
                CASE(4, ba);
                CASE(8, ba);
                CASE(16, ba);
            }
            if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) {
                CASE(4, bc);
                CASE(8, bc);
                CASE(16, bc);
            } else if (blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1) {
                CASE(4, cb);
                CASE(8, cb);
                CASE(16, cb);
            }
            break;
        default: break;
    }
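    // Example dispatch: an ABcd16a16b descriptor has inner_nblks == 2 and
    // inner_idxs == {0, 1}, so both blocked dimensions are 16 wide and
    // CASE(16, ab) handles it; any layout without a matching case drops to
    // the generic routine below.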

#undef CASE

    // the last line of defence
    typed_zero_pad_generic_blocked<dt>(mdw, data);

    ctx.unmap_memory_storage(memory_storage, mapped_ptr, ctx.stream());
    return success;
}

static status_t zero_pad(const memory_t *memory, const exec_ctx_t &ctx) {
    memory_desc_wrapper mdw(memory->md());
    switch (mdw.data_type()) {
        case f16: return typed_zero_pad<f16>(memory, ctx);
        case bf16: return typed_zero_pad<bf16>(memory, ctx);
        case f32: return typed_zero_pad<f32>(memory, ctx);
        case s32: return typed_zero_pad<s32>(memory, ctx);
        case s8: return typed_zero_pad<s8>(memory, ctx);
        case u8: return typed_zero_pad<u8>(memory, ctx);
        default: assert(!"memory is undefined"); return unimplemented;
    }
    return unimplemented;
}

status_t stream_t::zero_pad(const memory_t *memory, const exec_ctx_t &ctx) {
    return ::zero_pad(memory, ctx);
}

status_t memory_t::zero_pad(const exec_ctx_t &ctx) const {
    memory_desc_wrapper mdw(md());
    const bool skip_zeroing = memory_storage()->is_null() || mdw.is_zero()
            || !mdw.is_blocking_desc();
    if (skip_zeroing) return success;

    stream_t *stream = ctx.stream();
    status_t status;
    if (stream == nullptr) {
        engine_t *engine = memory_storage()->engine();
        CHECK(engine->get_service_stream(stream));
    }

    if (stream != nullptr)
        status = stream->zero_pad(this, ctx);
    else
        status = ::zero_pad(this, ctx);

    return status;
}

extern "C" dnnl_status_t DNNL_API dnnl_impl_zero_pad(
        const memory_t *memory, stream_t *stream) {
    if (memory == nullptr || stream == nullptr)
        return status::invalid_arguments;
    memory_arg_t mem_arg = {const_cast<memory_t *>(memory), true};
    exec_args_t args = {{0, mem_arg}};
    return memory->zero_pad(exec_ctx_t(stream, std::move(args)));
}
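
// A minimal usage sketch (hypothetical setup; dnnl_impl_zero_pad is an
// internal entry point, normally reached through memory_t::zero_pad rather
// than called directly):
//
//     memory_t *mem = ...; // memory whose padded_dims differ from dims
//     stream_t *stream = ...; // stream on the memory's engine
//     dnnl_status_t st = dnnl_impl_zero_pad(mem, stream);
//     assert(st == dnnl_success); // padding elements now hold zeroes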