1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <cctype> |
18 | |
19 | #include "oneapi/dnnl/dnnl.hpp" |
20 | |
21 | #include "common/c_types_map.hpp" |
22 | #include "common/memory_desc.hpp" |
23 | #include "common/memory_desc_wrapper.hpp" |
24 | #include "common/type_helpers.hpp" |
25 | #include "common/utils.hpp" |
26 | |
27 | using namespace dnnl::impl; |
28 | using namespace dnnl::impl::status; |
29 | using namespace dnnl::impl::utils; |
30 | |
31 | namespace dnnl { |
32 | namespace impl { |
33 | |
34 | status_t memory_desc_init_by_tag(memory_desc_t &memory_desc, int ndims, |
35 | const dims_t dims, data_type_t data_type, format_tag_t tag) { |
36 | if (ndims == 0 || tag == format_tag::undef) { |
37 | memory_desc = types::zero_md(); |
38 | return success; |
39 | } |
40 | |
41 | format_kind_t format_kind = types::format_tag_to_kind(tag); |
42 | |
43 | /* memory_desc != 0 */ |
44 | bool args_ok |
45 | = memory_desc_sanity_check(ndims, dims, data_type, format_kind); |
46 | if (!args_ok) return invalid_arguments; |
47 | |
48 | auto md = memory_desc_t(); |
49 | md.ndims = ndims; |
50 | array_copy(md.dims, dims, ndims); |
51 | md.data_type = data_type; |
52 | array_copy(md.padded_dims, dims, ndims); |
53 | md.format_kind = format_kind; |
54 | |
55 | status_t status = success; |
56 | if (tag == format_tag::undef) { |
57 | status = invalid_arguments; |
58 | } else if (tag == format_tag::any) { |
59 | // nop |
60 | } else if (format_kind == format_kind::blocked) { |
61 | status = memory_desc_wrapper::compute_blocking(md, tag); |
62 | } else { |
63 | assert(!"unreachable" ); |
64 | status = invalid_arguments; |
65 | } |
66 | |
67 | if (status == success) memory_desc = md; |
68 | |
69 | return status; |
70 | } |
71 | |
72 | status_t memory_desc_init_by_strides(memory_desc_t &memory_desc, int ndims, |
73 | const dims_t dims, data_type_t data_type, const dims_t strides) { |
74 | if (ndims == 0) { |
75 | memory_desc = types::zero_md(); |
76 | return success; |
77 | } |
78 | |
79 | /* memory_desc != 0 */ |
80 | bool args_ok = memory_desc_sanity_check( |
81 | ndims, dims, data_type, format_kind::undef); |
82 | if (!args_ok) return invalid_arguments; |
83 | |
84 | auto md = memory_desc_t(); |
85 | md.ndims = ndims; |
86 | array_copy(md.dims, dims, ndims); |
87 | md.data_type = data_type; |
88 | array_copy(md.padded_dims, dims, ndims); |
89 | md.format_kind = format_kind::blocked; |
90 | |
91 | dims_t default_strides = {0}; |
92 | if (strides == nullptr) { |
93 | bool has_runtime_strides = false; |
94 | default_strides[md.ndims - 1] = 1; |
95 | for (int d = md.ndims - 2; d >= 0; --d) { |
96 | if (md.padded_dims[d] == DNNL_RUNTIME_DIM_VAL) |
97 | has_runtime_strides = true; |
98 | default_strides[d] = has_runtime_strides |
99 | ? DNNL_RUNTIME_DIM_VAL |
100 | : default_strides[d + 1] * md.padded_dims[d + 1]; |
101 | } |
102 | strides = default_strides; |
103 | } |
104 | if (!memory_desc_strides_check(md, strides)) return invalid_arguments; |
105 | |
106 | array_copy(md.format_desc.blocking.strides, strides, md.ndims); |
107 | |
108 | memory_desc = md; |
109 | |
110 | return success; |
111 | } |
112 | |
// Initializes `memory_desc` as a sub-memory (a view) of `parent_memory_desc`
// that covers `dims` elements starting at `offsets`, without copying data:
// the view keeps the parent's blocking and encodes the shift via `offset0`.
//
// Only blocked parents with compile-time dims/strides are supported, and
// each offset must be a whole number of blocks (see [r1] below).
//
// @param memory_desc output descriptor; written only on success
// @param parent_memory_desc descriptor to take the view of
// @param dims per-dimension sizes of the view
// @param offsets per-dimension start of the view inside the parent
// @returns success; invalid_arguments for out-of-range requests;
//     unimplemented for cases this code does not handle yet
status_t memory_desc_init_submemory(memory_desc_t &memory_desc,
        const memory_desc_t &parent_memory_desc, const dims_t dims,
        const dims_t offsets) {
    if (!memory_desc_sanity_check(parent_memory_desc)) return invalid_arguments;

    const memory_desc_wrapper src_d(parent_memory_desc);
    if (src_d.has_runtime_dims_or_strides()) return unimplemented;

    // The requested window must be non-negative and lie inside the parent.
    for (int d = 0; d < src_d.ndims(); ++d) {
        if (utils::one_of(DNNL_RUNTIME_DIM_VAL, dims[d], offsets[d]))
            return unimplemented;

        if (dims[d] < 0 || offsets[d] < 0
                || (offsets[d] + dims[d] > src_d.dims()[d]))
            return invalid_arguments;
    }

    if (src_d.format_kind() != format_kind::blocked) return unimplemented;

    // Combined block size per dimension (product of the inner blocks).
    dims_t blocks;
    src_d.compute_blocks(blocks);

    memory_desc_t dst_d = parent_memory_desc;
    auto &dst_d_blk = dst_d.format_desc.blocking;

    /* TODO: put this into memory_desc_wrapper */
    for (int d = 0; d < src_d.ndims(); ++d) {
        const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d];

        /* very limited functionality for now */
        const bool ok = offsets[d] % blocks[d] == 0 /* [r1] */
                && src_d.padded_offsets()[d] == 0
                && IMPLICATION(!is_right_border,
                        (dims[d] % blocks[d] == 0 || dims[d] < blocks[d]));
        if (!ok) return unimplemented;

        dst_d.dims[d] = dims[d];
        // A view that reaches the right border inherits the parent's padding
        // for that dimension; interior views are taken as unpadded.
        dst_d.padded_dims[d] = is_right_border
                ? src_d.padded_dims()[d] - offsets[d]
                : dst_d.dims[d];
        dst_d.padded_offsets[d] = src_d.padded_offsets()[d];
        // [r1] guarantees offsets[d] is a whole number of blocks, so the
        // element shift can be expressed through the per-block stride.
        dst_d.offset0 += /* [r1] */
                offsets[d] / blocks[d] * dst_d_blk.strides[d];
    }

    memory_desc = dst_d;

    return success;
}
162 | |
// Reshapes `in_memory_desc` into `out_memory_desc` with new `dims` of the
// same total volume, preserving the physical layout.
//
// Both dim arrays are scanned backwards and partitioned into matching
// groups (see `find_groups` below); each group is handled independently:
// size-1 axes may be dropped or added, single matching axes are kept
// (possibly at a new index), and dense, non-blocked sub-ranges may be
// re-factorized into a different set of axes of the same volume.
//
// @param out_memory_desc output descriptor; written only on success
// @param in_memory_desc source descriptor (format_kind any or blocked, no
//     runtime dims/strides, no extra flags, not a zero md)
// @param ndims new number of dimensions
// @param dims new dimensions; their volume must equal the input's volume
// @returns success, or invalid_arguments if the reshape cannot be expressed
//     without changing the physical layout
status_t memory_desc_reshape(memory_desc_t &out_memory_desc,
        const memory_desc_t &in_memory_desc, int ndims, const dims_t dims) {
    // Product of dims; non-positive dims count as 1, and any runtime dim
    // poisons the result to DNNL_RUNTIME_DIM_VAL.
    auto volume = [](const dim_t *dims, int ndims) -> dim_t {
        dim_t prod = 1;
        for (int i = 0; i < ndims; ++i) {
            if (dims[i] == DNNL_RUNTIME_DIM_VAL) return DNNL_RUNTIME_DIM_VAL;
            prod *= dims[i] > 0 ? dims[i] : 1;
        }
        return prod;
    };

    if (!memory_desc_sanity_check(in_memory_desc)
            || !memory_desc_sanity_check(ndims, dims, in_memory_desc.data_type,
                    in_memory_desc.format_kind)
            || !one_of(in_memory_desc.format_kind, format_kind::any,
                    format_kind::blocked)
            || types::is_zero_md(&in_memory_desc)
            || volume(in_memory_desc.dims, in_memory_desc.ndims)
                    != volume(dims, ndims)
            || memory_desc_wrapper(in_memory_desc).has_runtime_dims_or_strides()
            || in_memory_desc.extra.flags != 0)
        return invalid_arguments;

    // For format_kind::any there is no concrete layout to preserve: simply
    // re-initialize with the new dims.
    if (in_memory_desc.format_kind == format_kind::any)
        return memory_desc_init_by_tag(out_memory_desc, ndims, dims,
                in_memory_desc.data_type, format_tag::any);

    assert(in_memory_desc.format_kind == format_kind::blocked);
    assert(in_memory_desc.extra.flags == 0);

    // temporary output
    auto md = in_memory_desc;

    md.ndims = ndims;
    array_copy(md.dims, dims, md.ndims);

    const int i_ndims = in_memory_desc.ndims;
    const int o_ndims = md.ndims;

    const auto &i_dims = in_memory_desc.dims,
               &i_pdims = in_memory_desc.padded_dims;
    const auto &o_dims = md.dims;

    const auto &i_bd = in_memory_desc.format_desc.blocking;
    auto &o_bd = md.format_desc.blocking;

    // Combined block size per input dim (product of its inner blocks).
    dims_t blocks = {0};
    memory_desc_wrapper(in_memory_desc).compute_blocks(blocks);

    enum class action_t { REMOVE_1, ADD_1, KEEP_DIM, REARRANGE_DIMS, FAIL };

    // Determine groups in input and output dims starting with the given
    // positions (going backwards) that satisfy one of the conditions:
    // - REMOVE_1
    //      input_group = {1}, output_group = empty
    // - ADD_1
    //      input_group = empty, output_group = {1}
    // - KEEP_DIM
    //      input_group = {x}, output_group = {x}
    // - REARRANGE_DIMS
    //      input_group = {x1, x2, .., xk}, output_group = {y1, y2, ..., ym}
    //      and product(x_i) = product(y_j), and the groups are minimal
    // - FAIL
    //      invalid configuration (return false)
    auto find_groups
            = [&](int &i_group_begin, int i_group_end, int &o_group_begin,
                      int o_group_end) -> action_t {
        // 1st step: check for `1` in the input dims
        if (i_group_end > 0 && i_dims[i_group_end - 1] == 1) {
            i_group_begin = i_group_end - 1;
            if (i_pdims[i_group_end - 1] == 1) {
                // unpadded unit dim: can be dropped outright
                o_group_begin = o_group_end;
                return action_t::REMOVE_1;
            } else if (o_group_end > 0 && o_dims[o_group_end - 1] == 1) {
                // padded unit dim must survive in the output to keep its
                // padding
                o_group_begin = o_group_end - 1;
                return action_t::KEEP_DIM;
            } else {
                return action_t::FAIL;
            }
        }

        // 2nd step: check for `1` in the output dims
        if (o_group_end > 0 && o_dims[o_group_end - 1] == 1) {
            i_group_begin = i_group_end;
            o_group_begin = o_group_end - 1;
            return action_t::ADD_1;
        }

        // at this moment both groups cannot be empty
        if (i_group_end == 0 || o_group_end == 0) return action_t::FAIL;

        // 3rd step: find the non-trivial groups of the same volume
        i_group_begin = i_group_end - 1;
        o_group_begin = o_group_end - 1;

        dim_t i_volume = i_dims[i_group_begin];
        dim_t o_volume = o_dims[o_group_begin];

        // Grow whichever group is smaller until the volumes meet.
        while (i_volume != o_volume) {
            if (i_volume < o_volume) {
                if (i_group_begin == 0) return action_t::FAIL;
                i_volume *= i_dims[--i_group_begin];

                // do not allow `0` axis in the middle
                if (i_volume == 0) return action_t::FAIL;
            } else {
                if (o_group_begin == 0) return action_t::FAIL;
                o_volume *= o_dims[--o_group_begin];

                // do not allow `0` axis in the middle
                if (o_volume == 0) return action_t::FAIL;
            }
        }

        assert(i_volume == o_volume);
        assert(i_group_begin >= 0);
        assert(o_group_begin >= 0);

        return (i_group_begin + 1 == i_group_end
                       && o_group_begin + 1 == o_group_end)
                ? action_t::KEEP_DIM
                : action_t::REARRANGE_DIMS;
    };

    // Walk both dim arrays from the back, consuming one group per iteration.
    int i_group_begin = i_ndims, i_group_end = i_ndims;
    int o_group_begin = o_ndims, o_group_end = o_ndims;

    while (i_group_end != 0 || o_group_end != 0) {
        action_t action = find_groups(
                i_group_begin, i_group_end, o_group_begin, o_group_end);

        if (action == action_t::REMOVE_1) {
            // nop, padding is already taken into account by `find_groups()`
        } else if (action == action_t::ADD_1) {
            // get the stride from the right
            dim_t current_stride = 1;
            if (i_group_begin == i_ndims) {
                for (int d = 0; d < i_bd.inner_nblks; ++d)
                    current_stride *= i_bd.inner_blks[d];
            } else {
                // Add `1` to the left from axes with index `i_group_begin`
                current_stride
                        = i_bd.strides[i_group_begin] * i_dims[i_group_begin];
                for (int d = 0; d < i_bd.inner_nblks; ++d)
                    if (i_bd.inner_idxs[d] == i_group_begin)
                        current_stride /= i_bd.inner_blks[d];
            }
            md.padded_dims[o_group_begin] = 1;
            md.padded_offsets[o_group_begin] = 0;
            o_bd.strides[o_group_begin] = current_stride;
        } else if (action == action_t::KEEP_DIM) {
            // change the axis index from `i_group_begin` to `o_group_begin`
            assert(i_group_begin + 1 == i_group_end);
            assert(o_group_begin + 1 == o_group_end);

            md.padded_dims[o_group_begin]
                    = in_memory_desc.padded_dims[i_group_begin];
            md.padded_offsets[o_group_begin]
                    = in_memory_desc.padded_offsets[i_group_begin];
            o_bd.strides[o_group_begin] = i_bd.strides[i_group_begin];
            // Inner blocks stay in place; only their axis index is remapped.
            for (int d = 0; d < i_bd.inner_nblks; ++d)
                if (i_bd.inner_idxs[d] == i_group_begin)
                    o_bd.inner_idxs[d] = o_group_begin;
        } else if (action == action_t::REARRANGE_DIMS) {
            // check that input group is dense, sequential, and is not blocked
            for (int d = i_group_end - 1; d > i_group_begin; --d)
                if (i_dims[d] * i_bd.strides[d] != i_bd.strides[d - 1])
                    return invalid_arguments;

            // checked (i_group_begin, i_group_end), `i_group_begin` remains
            for (int d = 0; d < i_bd.inner_nblks; ++d)
                if (i_bd.inner_idxs[d] == i_group_begin)
                    return invalid_arguments;
            if (in_memory_desc.padded_dims[i_group_begin]
                    != i_dims[i_group_begin])
                return invalid_arguments;
            if (in_memory_desc.padded_offsets[i_group_begin] != 0)
                return invalid_arguments;

            // oK, fill output md according to
            // o_dims[o_group_begin .. o_group_end]

            dim_t current_stride = i_bd.strides[i_group_end - 1];
            for (int d = o_group_end - 1; d >= o_group_begin; --d) {
                md.padded_dims[d] = o_dims[d];
                md.padded_offsets[d] = 0;
                o_bd.strides[d] = current_stride;
                current_stride *= md.padded_dims[d];
            }
        } else {
            assert(action == action_t::FAIL);
            return invalid_arguments;
        }

        i_group_end = i_group_begin;
        o_group_end = o_group_begin;
    }

    out_memory_desc = md;
    return success;
}
364 | |
// Initializes `out_memory_desc` as `in_memory_desc` with its axes permuted:
// axis `d` of the input becomes axis `perm[d]` of the output (dims, padded
// dims/offsets, strides, and inner block indices all move accordingly).
//
// @param out_memory_desc output descriptor
// @param in_memory_desc source descriptor (format_kind any or blocked, no
//     runtime dims/strides, no extra flags, not a zero md)
// @param perm permutation of [0 .. ndims); perm[d] is the output index of
//     input axis d
// @returns success, or invalid_arguments on a bad descriptor/permutation
status_t memory_desc_permute_axes(memory_desc_t &out_memory_desc,
        const memory_desc_t &in_memory_desc, const int *perm) {
    if (!memory_desc_sanity_check(in_memory_desc)
            || !one_of(in_memory_desc.format_kind, format_kind::any,
                    format_kind::blocked)
            || types::is_zero_md(&in_memory_desc)
            || memory_desc_wrapper(in_memory_desc).has_runtime_dims_or_strides()
            || in_memory_desc.extra.flags != 0)
        return invalid_arguments;

    // verify that perm is indeed a permutation of [0 .. ndims)
    // (each distinct in-range value contributes one bit; the mask must end
    // up being exactly 2^ndims - 1)
    unsigned occurrence_mask = 0;
    for (int d = 0; d < in_memory_desc.ndims; ++d)
        if (0 <= perm[d] && perm[d] < in_memory_desc.ndims)
            occurrence_mask |= (1u << perm[d]);
    if (occurrence_mask + 1 != (1u << in_memory_desc.ndims))
        return invalid_arguments;

    // Start from a full copy, then move per-axis attributes to their slots.
    out_memory_desc = in_memory_desc;
    for (int d = 0; d < in_memory_desc.ndims; ++d) {
        if (perm[d] == d) continue;
        out_memory_desc.dims[perm[d]] = in_memory_desc.dims[d];
        out_memory_desc.padded_dims[perm[d]] = in_memory_desc.padded_dims[d];
        out_memory_desc.padded_offsets[perm[d]]
                = in_memory_desc.padded_offsets[d];
        if (in_memory_desc.format_kind == format_kind::blocked) {
            const auto &i_bd = in_memory_desc.format_desc.blocking;
            auto &o_bd = out_memory_desc.format_desc.blocking;

            o_bd.strides[perm[d]] = i_bd.strides[d];
            // Inner blocks keep their order; only the axis they refer to is
            // remapped.
            for (int blk = 0; blk < i_bd.inner_nblks; ++blk)
                if (i_bd.inner_idxs[blk] == d) o_bd.inner_idxs[blk] = perm[d];
        }
    }

    return success;
}
402 | |
403 | // This is only used by internal API that is used for testing only. |
404 | status_t memory_desc_init_by_string_tag(memory_desc_t &md, int ndims, |
405 | const dims_t dims, data_type_t data_type, const std::string &tag) { |
406 | // Copy to temporary to handle dims == md->dims case. |
407 | dims_t tmp_dims; |
408 | std::copy(dims, dims + ndims, tmp_dims); |
409 | |
410 | md.ndims = ndims; |
411 | if (ndims < 0 || ndims > DNNL_MAX_NDIMS) return invalid_arguments; |
412 | |
413 | std::copy(tmp_dims, tmp_dims + ndims, md.dims); |
414 | md.data_type = data_type; |
415 | md.format_kind = format_kind::blocked; |
416 | |
417 | // Parse dimensions and their block sizes starting from the innermost one. |
418 | std::vector<std::pair<int, int>> dim_blocks; |
419 | int pos = (int)tag.size() - 1; |
420 | int ndims_from_tag = -1; |
421 | while (pos >= 0) { |
422 | int pos0 = pos; |
423 | |
424 | --pos; |
425 | while (pos >= 0 && std::isdigit(tag[pos])) |
426 | pos--; |
427 | |
428 | int dim_idx = std::tolower(tag[pos0]) - 'a'; |
429 | if (dim_idx >= ndims) return invalid_arguments; |
430 | ndims_from_tag = std::max(dim_idx + 1, ndims_from_tag); |
431 | int block_str_len = pos0 - pos - 1; |
432 | int block = (block_str_len == 0) |
433 | ? 1 |
434 | : std::stoi(tag.substr(pos + 1, block_str_len)); |
435 | dim_blocks.emplace_back(dim_idx, block); |
436 | } |
437 | if (ndims_from_tag != ndims) return invalid_arguments; |
438 | |
439 | auto &blk = md.format_desc.blocking; |
440 | |
441 | // Compute strides and fill inner block sizes/indices. |
442 | dim_t stride = 1; |
443 | dims_t full_inner_blks; |
444 | std::fill(full_inner_blks, full_inner_blks + ndims, 1); |
445 | for (auto &p : dim_blocks) { |
446 | int dim_idx = p.first; |
447 | int block = p.second; |
448 | if (block == 1) { |
449 | assert(blk.strides[dim_idx] == 0); |
450 | blk.strides[dim_idx] = stride; |
451 | |
452 | dim_t fib = full_inner_blks[dim_idx]; |
453 | dim_t padded_dim = md.dims[dim_idx] == DNNL_RUNTIME_DIM_VAL |
454 | ? DNNL_RUNTIME_DIM_VAL |
455 | : (md.dims[dim_idx] + fib - 1) / fib * fib; |
456 | md.padded_dims[dim_idx] = padded_dim; |
457 | if (padded_dim == DNNL_RUNTIME_DIM_VAL) |
458 | stride = DNNL_RUNTIME_DIM_VAL; |
459 | else |
460 | stride *= (padded_dim / fib); |
461 | } else { |
462 | full_inner_blks[dim_idx] *= block; |
463 | blk.inner_blks[blk.inner_nblks] = block; |
464 | blk.inner_idxs[blk.inner_nblks] = dim_idx; |
465 | blk.inner_nblks++; |
466 | stride *= block; |
467 | } |
468 | } |
469 | |
470 | // Inner block sizes/indices are stored from the outermost to the innermost |
471 | // so need to reverse them. |
472 | std::reverse(blk.inner_blks, blk.inner_blks + blk.inner_nblks); |
473 | std::reverse(blk.inner_idxs, blk.inner_idxs + blk.inner_nblks); |
474 | |
475 | return success; |
476 | } |
477 | |
478 | } // namespace impl |
479 | } // namespace dnnl |
480 | |
481 | // API |
482 | status_t dnnl_memory_desc_create_with_tag(memory_desc_t **memory_desc, |
483 | int ndims, const dims_t dims, data_type_t data_type, format_tag_t tag) { |
484 | if (any_null(memory_desc)) return invalid_arguments; |
485 | |
486 | auto md = utils::make_unique<memory_desc_t>(); |
487 | if (!md) return out_of_memory; |
488 | CHECK(memory_desc_init_by_tag(*md, ndims, dims, data_type, tag)); |
489 | (*memory_desc) = md.release(); |
490 | return success; |
491 | } |
492 | |
493 | status_t dnnl_memory_desc_create_with_strides(memory_desc_t **memory_desc, |
494 | int ndims, const dims_t dims, data_type_t data_type, |
495 | const dims_t strides) { |
496 | if (any_null(memory_desc)) return invalid_arguments; |
497 | |
498 | auto md = utils::make_unique<memory_desc_t>(); |
499 | if (!md) return out_of_memory; |
500 | CHECK(memory_desc_init_by_strides(*md, ndims, dims, data_type, strides)); |
501 | (*memory_desc) = md.release(); |
502 | return success; |
503 | } |
504 | |
505 | status_t dnnl_memory_desc_create_submemory(memory_desc_t **memory_desc, |
506 | const memory_desc_t *parent_memory_desc, const dims_t dims, |
507 | const dims_t offsets) { |
508 | if (any_null(memory_desc, parent_memory_desc)) return invalid_arguments; |
509 | |
510 | auto md = utils::make_unique<memory_desc_t>(); |
511 | if (!md) return out_of_memory; |
512 | CHECK(memory_desc_init_submemory(*md, *parent_memory_desc, dims, offsets)); |
513 | (*memory_desc) = md.release(); |
514 | return success; |
515 | } |
516 | |
517 | status_t dnnl_memory_desc_reshape(memory_desc_t **out_memory_desc, |
518 | const memory_desc_t *in_memory_desc, int ndims, const dims_t dims) { |
519 | if (any_null(out_memory_desc, in_memory_desc)) return invalid_arguments; |
520 | |
521 | auto md = utils::make_unique<memory_desc_t>(); |
522 | if (!md) return out_of_memory; |
523 | CHECK(memory_desc_reshape(*md, *in_memory_desc, ndims, dims)); |
524 | (*out_memory_desc) = md.release(); |
525 | return success; |
526 | } |
527 | |
528 | status_t dnnl_memory_desc_permute_axes(memory_desc_t **out_memory_desc, |
529 | const memory_desc_t *in_memory_desc, const int *perm) { |
530 | if (any_null(out_memory_desc, in_memory_desc)) return invalid_arguments; |
531 | |
532 | auto md = utils::make_unique<memory_desc_t>(); |
533 | if (!md) return out_of_memory; |
534 | CHECK(memory_desc_permute_axes(*md, *in_memory_desc, perm)); |
535 | (*out_memory_desc) = md.release(); |
536 | return success; |
537 | } |
538 | |
539 | int dnnl_memory_desc_equal(const memory_desc_t *lhs, const memory_desc_t *rhs) { |
540 | if (lhs == rhs) return 1; |
541 | if (any_null(lhs, rhs)) return 0; |
542 | return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs); |
543 | } |
544 | |
545 | size_t dnnl_memory_desc_get_size(const memory_desc_t *md) { |
546 | if (md == nullptr) return 0; |
547 | return memory_desc_wrapper(*md).size(); |
548 | } |
549 | |
550 | size_t dnnl_data_type_size(dnnl_data_type_t data_type) { |
551 | return types::data_type_size(data_type); |
552 | } |
553 | |
554 | status_t dnnl_memory_desc_query( |
555 | const memory_desc_t *md, query_t what, void *result) { |
556 | const bool is_blocked = md->format_kind == format_kind::blocked; |
557 | |
558 | switch (what) { |
559 | case query::ndims_s32: *(int32_t *)result = md->ndims; break; |
560 | case query::dims: *(const dims_t **)result = &md->dims; break; |
561 | case query::data_type: *(data_type_t *)result = md->data_type; break; |
562 | case query::submemory_offset_s64: *(dim_t *)result = md->offset0; break; |
563 | case query::padded_dims: |
564 | *(const dims_t **)result = &md->padded_dims; |
565 | break; |
566 | case query::padded_offsets: |
567 | *(const dims_t **)result = &md->padded_offsets; |
568 | break; |
569 | case query::format_kind: |
570 | if (one_of(md->format_kind, format_kind::rnn_packed, |
571 | format_kind::wino)) { |
572 | *(format_kind_t *)result = format_kind::opaque; |
573 | break; |
574 | } |
575 | *(format_kind_t *)result = md->format_kind; |
576 | break; |
577 | case query::strides: |
578 | if (!is_blocked) return status::invalid_arguments; |
579 | *(const dims_t **)result = &md->format_desc.blocking.strides; |
580 | break; |
581 | case query::inner_nblks_s32: |
582 | if (!is_blocked) return status::invalid_arguments; |
583 | *(int32_t *)result = md->format_desc.blocking.inner_nblks; |
584 | break; |
585 | case query::inner_blks: |
586 | if (!is_blocked) return status::invalid_arguments; |
587 | *(const dims_t **)result = &md->format_desc.blocking.inner_blks; |
588 | break; |
589 | case query::inner_idxs: |
590 | if (!is_blocked) return status::invalid_arguments; |
591 | *(const dims_t **)result = &md->format_desc.blocking.inner_idxs; |
592 | break; |
593 | default: return status::unimplemented; |
594 | } |
595 | return status::success; |
596 | } |
597 | |
598 | status_t dnnl_memory_desc_destroy(memory_desc_t *memory_desc) { |
599 | delete memory_desc; |
600 | return success; |
601 | } |
602 | |
603 | status_t dnnl_memory_desc_clone(memory_desc_t **memory_desc, |
604 | const memory_desc_t *existing_memory_desc) { |
605 | (*memory_desc) = new memory_desc_t(*existing_memory_desc); |
606 | return success; |
607 | } |
608 | |
609 | // This is an internal API that is used only for testing in benchdnn. |
610 | extern "C" status_t DNNL_API dnnl_memory_desc_create_with_string_tag( |
611 | memory_desc_t **memory_desc, int ndims, const dims_t dims, |
612 | data_type_t data_type, const char *tag) { |
613 | if (any_null(memory_desc)) return invalid_arguments; |
614 | |
615 | auto md = utils::make_unique<memory_desc_t>(); |
616 | if (!md) return out_of_memory; |
617 | CHECK(memory_desc_init_by_string_tag(*md, ndims, dims, data_type, tag)); |
618 | (*memory_desc) = md.release(); |
619 | return success; |
620 | } |
621 | |
622 | extern "C" status_t DNNL_API dnnl_memory_desc_set_data_type( |
623 | memory_desc_t *memory_desc, data_type_t data_type) { |
624 | if (any_null(memory_desc)) return invalid_arguments; |
625 | memory_desc->data_type = data_type; |
626 | return success; |
627 | } |
628 | |