/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <algorithm>
#include <cctype>
#include <string>
#include <utility>
#include <vector>

#include "oneapi/dnnl/dnnl.hpp"

#include "common/c_types_map.hpp"
#include "common/memory_desc.hpp"
#include "common/memory_desc_wrapper.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

using namespace dnnl::impl;
using namespace dnnl::impl::status;
using namespace dnnl::impl::utils;

namespace dnnl {
namespace impl {

status_t memory_desc_init_by_tag(memory_desc_t &memory_desc, int ndims,
        const dims_t dims, data_type_t data_type, format_tag_t tag) {
    if (ndims == 0 || tag == format_tag::undef) {
        memory_desc = types::zero_md();
        return success;
    }

    format_kind_t format_kind = types::format_tag_to_kind(tag);

    /* memory_desc != 0 */
    bool args_ok
            = memory_desc_sanity_check(ndims, dims, data_type, format_kind);
    if (!args_ok) return invalid_arguments;

    auto md = memory_desc_t();
    md.ndims = ndims;
    array_copy(md.dims, dims, ndims);
    md.data_type = data_type;
    array_copy(md.padded_dims, dims, ndims);
    md.format_kind = format_kind;

    status_t status = success;
    if (tag == format_tag::undef) {
        status = invalid_arguments;
    } else if (tag == format_tag::any) {
        // nop
    } else if (format_kind == format_kind::blocked) {
        status = memory_desc_wrapper::compute_blocking(md, tag);
    } else {
        assert(!"unreachable");
        status = invalid_arguments;
    }

    if (status == success) memory_desc = md;

    return status;
}
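
// Example (illustrative sketch, not library documentation): for an f32 tensor
// with dims = {2, 16, 3, 4} and tag = format_tag::nchw, compute_blocking()
// produces dense row-major strides {16 * 3 * 4, 3 * 4, 4, 1}:
//
//     memory_desc_t md;
//     const dims_t dims = {2, 16, 3, 4};
//     memory_desc_init_by_tag(md, 4, dims, data_type::f32, format_tag::nchw);
//     // md.format_desc.blocking.strides == {192, 12, 4, 1}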

status_t memory_desc_init_by_strides(memory_desc_t &memory_desc, int ndims,
        const dims_t dims, data_type_t data_type, const dims_t strides) {
    if (ndims == 0) {
        memory_desc = types::zero_md();
        return success;
    }

    /* memory_desc != 0 */
    bool args_ok = memory_desc_sanity_check(
            ndims, dims, data_type, format_kind::undef);
    if (!args_ok) return invalid_arguments;

    auto md = memory_desc_t();
    md.ndims = ndims;
    array_copy(md.dims, dims, ndims);
    md.data_type = data_type;
    array_copy(md.padded_dims, dims, ndims);
    md.format_kind = format_kind::blocked;

    dims_t default_strides = {0};
    if (strides == nullptr) {
        bool has_runtime_strides = false;
        default_strides[md.ndims - 1] = 1;
        for (int d = md.ndims - 2; d >= 0; --d) {
            // The stride of dim `d` is derived from dim `d + 1`, so strides
            // become runtime starting one position to the left of the first
            // (rightmost) runtime dimension.
            if (md.padded_dims[d + 1] == DNNL_RUNTIME_DIM_VAL)
                has_runtime_strides = true;
            default_strides[d] = has_runtime_strides
                    ? DNNL_RUNTIME_DIM_VAL
                    : default_strides[d + 1] * md.padded_dims[d + 1];
        }
        strides = default_strides;
    }
    if (!memory_desc_strides_check(md, strides)) return invalid_arguments;

    array_copy(md.format_desc.blocking.strides, strides, md.ndims);

    memory_desc = md;

    return success;
}
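
// Example (illustrative sketch): with strides == nullptr the descriptor gets
// dense row-major strides, and a runtime dimension makes every stride to its
// left runtime as well:
//
//     dims = {2, 3, 4}
//         -> strides = {12, 4, 1}
//     dims = {2, DNNL_RUNTIME_DIM_VAL, 4}
//         -> strides = {DNNL_RUNTIME_DIM_VAL, 4, 1}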

status_t memory_desc_init_submemory(memory_desc_t &memory_desc,
        const memory_desc_t &parent_memory_desc, const dims_t dims,
        const dims_t offsets) {
    if (!memory_desc_sanity_check(parent_memory_desc)) return invalid_arguments;

    const memory_desc_wrapper src_d(parent_memory_desc);
    if (src_d.has_runtime_dims_or_strides()) return unimplemented;

    for (int d = 0; d < src_d.ndims(); ++d) {
        if (utils::one_of(DNNL_RUNTIME_DIM_VAL, dims[d], offsets[d]))
            return unimplemented;

        if (dims[d] < 0 || offsets[d] < 0
                || (offsets[d] + dims[d] > src_d.dims()[d]))
            return invalid_arguments;
    }

    if (src_d.format_kind() != format_kind::blocked) return unimplemented;

    dims_t blocks;
    src_d.compute_blocks(blocks);

    memory_desc_t dst_d = parent_memory_desc;
    auto &dst_d_blk = dst_d.format_desc.blocking;

    /* TODO: put this into memory_desc_wrapper */
    for (int d = 0; d < src_d.ndims(); ++d) {
        const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d];

        /* very limited functionality for now */
        const bool ok = offsets[d] % blocks[d] == 0 /* [r1] */
                && src_d.padded_offsets()[d] == 0
                && IMPLICATION(!is_right_border,
                        (dims[d] % blocks[d] == 0 || dims[d] < blocks[d]));
        if (!ok) return unimplemented;

        dst_d.dims[d] = dims[d];
        dst_d.padded_dims[d] = is_right_border
                ? src_d.padded_dims()[d] - offsets[d]
                : dst_d.dims[d];
        dst_d.padded_offsets[d] = src_d.padded_offsets()[d];
        dst_d.offset0 += /* [r1] */
                offsets[d] / blocks[d] * dst_d_blk.strides[d];
    }

    memory_desc = dst_d;

    return success;
}
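
// Example (illustrative sketch): a 4x4 submemory at offset (2, 2) of a plain
// 8x8 parent (strides {8, 1}) keeps the parent's strides and folds the
// element offset into offset0:
//
//     dims = {4, 4}, offsets = {2, 2}
//     -> dst.dims = {4, 4}, strides = {8, 1},
//        offset0 = 2 * 8 + 2 * 1 = 18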

status_t memory_desc_reshape(memory_desc_t &out_memory_desc,
        const memory_desc_t &in_memory_desc, int ndims, const dims_t dims) {
    auto volume = [](const dim_t *dims, int ndims) -> dim_t {
        dim_t prod = 1;
        for (int i = 0; i < ndims; ++i) {
            if (dims[i] == DNNL_RUNTIME_DIM_VAL) return DNNL_RUNTIME_DIM_VAL;
            prod *= dims[i] > 0 ? dims[i] : 1;
        }
        return prod;
    };

    if (!memory_desc_sanity_check(in_memory_desc)
            || !memory_desc_sanity_check(ndims, dims, in_memory_desc.data_type,
                    in_memory_desc.format_kind)
            || !one_of(in_memory_desc.format_kind, format_kind::any,
                    format_kind::blocked)
            || types::is_zero_md(&in_memory_desc)
            || volume(in_memory_desc.dims, in_memory_desc.ndims)
                    != volume(dims, ndims)
            || memory_desc_wrapper(in_memory_desc).has_runtime_dims_or_strides()
            || in_memory_desc.extra.flags != 0)
        return invalid_arguments;

    if (in_memory_desc.format_kind == format_kind::any)
        return memory_desc_init_by_tag(out_memory_desc, ndims, dims,
                in_memory_desc.data_type, format_tag::any);

    assert(in_memory_desc.format_kind == format_kind::blocked);
    assert(in_memory_desc.extra.flags == 0);

    // temporary output
    auto md = in_memory_desc;

    md.ndims = ndims;
    array_copy(md.dims, dims, md.ndims);

    const int i_ndims = in_memory_desc.ndims;
    const int o_ndims = md.ndims;

    const auto &i_dims = in_memory_desc.dims,
               &i_pdims = in_memory_desc.padded_dims;
    const auto &o_dims = md.dims;

    const auto &i_bd = in_memory_desc.format_desc.blocking;
    auto &o_bd = md.format_desc.blocking;

    dims_t blocks = {0};
    memory_desc_wrapper(in_memory_desc).compute_blocks(blocks);

    enum class action_t { REMOVE_1, ADD_1, KEEP_DIM, REARRANGE_DIMS, FAIL };

    // Determine groups in input and output dims starting with the given
    // positions (going backwards) that satisfy one of the conditions:
    // - REMOVE_1
    //      input_group = {1}, output_group = empty
    // - ADD_1
    //      input_group = empty, output_group = {1}
    // - KEEP_DIM
    //      input_group = {x}, output_group = {x}
    // - REARRANGE_DIMS
    //      input_group = {x1, x2, .., xk}, output_group = {y1, y2, ..., ym}
    //      and product(x_i) = product(y_j), and the groups are minimal
    // - FAIL
    //      invalid configuration (return false)
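    //
    // Example (illustrative): reshaping {6, 4, 1} -> {2, 12} is processed
    // right to left as
    //      REMOVE_1:        input {1}    <-> output {}
    //      REARRANGE_DIMS:  input {6, 4} <-> output {2, 12} (volume 24 each)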
    auto find_groups
            = [&](int &i_group_begin, int i_group_end, int &o_group_begin,
                      int o_group_end) -> action_t {
        // 1st step: check for `1` in the input dims
        if (i_group_end > 0 && i_dims[i_group_end - 1] == 1) {
            i_group_begin = i_group_end - 1;
            if (i_pdims[i_group_end - 1] == 1) {
                o_group_begin = o_group_end;
                return action_t::REMOVE_1;
            } else if (o_group_end > 0 && o_dims[o_group_end - 1] == 1) {
                o_group_begin = o_group_end - 1;
                return action_t::KEEP_DIM;
            } else {
                return action_t::FAIL;
            }
        }

        // 2nd step: check for `1` in the output dims
        if (o_group_end > 0 && o_dims[o_group_end - 1] == 1) {
            i_group_begin = i_group_end;
            o_group_begin = o_group_end - 1;
            return action_t::ADD_1;
        }

        // at this moment both groups cannot be empty
        if (i_group_end == 0 || o_group_end == 0) return action_t::FAIL;

        // 3rd step: find the non-trivial groups of the same volume
        i_group_begin = i_group_end - 1;
        o_group_begin = o_group_end - 1;

        dim_t i_volume = i_dims[i_group_begin];
        dim_t o_volume = o_dims[o_group_begin];

        while (i_volume != o_volume) {
            if (i_volume < o_volume) {
                if (i_group_begin == 0) return action_t::FAIL;
                i_volume *= i_dims[--i_group_begin];

                // do not allow `0` axis in the middle
                if (i_volume == 0) return action_t::FAIL;
            } else {
                if (o_group_begin == 0) return action_t::FAIL;
                o_volume *= o_dims[--o_group_begin];

                // do not allow `0` axis in the middle
                if (o_volume == 0) return action_t::FAIL;
            }
        }

        assert(i_volume == o_volume);
        assert(i_group_begin >= 0);
        assert(o_group_begin >= 0);

        return (i_group_begin + 1 == i_group_end
                       && o_group_begin + 1 == o_group_end)
                ? action_t::KEEP_DIM
                : action_t::REARRANGE_DIMS;
    };

    int i_group_begin = i_ndims, i_group_end = i_ndims;
    int o_group_begin = o_ndims, o_group_end = o_ndims;

    while (i_group_end != 0 || o_group_end != 0) {
        action_t action = find_groups(
                i_group_begin, i_group_end, o_group_begin, o_group_end);

        if (action == action_t::REMOVE_1) {
            // nop, padding is already taken into account by `find_groups()`
        } else if (action == action_t::ADD_1) {
            // get the stride from the right
            dim_t current_stride = 1;
            if (i_group_begin == i_ndims) {
                for (int d = 0; d < i_bd.inner_nblks; ++d)
                    current_stride *= i_bd.inner_blks[d];
            } else {
                // Add `1` to the left from axes with index `i_group_begin`
                current_stride
                        = i_bd.strides[i_group_begin] * i_dims[i_group_begin];
                for (int d = 0; d < i_bd.inner_nblks; ++d)
                    if (i_bd.inner_idxs[d] == i_group_begin)
                        current_stride /= i_bd.inner_blks[d];
            }
            md.padded_dims[o_group_begin] = 1;
            md.padded_offsets[o_group_begin] = 0;
            o_bd.strides[o_group_begin] = current_stride;
        } else if (action == action_t::KEEP_DIM) {
            // change the axis index from `i_group_begin` to `o_group_begin`
            assert(i_group_begin + 1 == i_group_end);
            assert(o_group_begin + 1 == o_group_end);

            md.padded_dims[o_group_begin]
                    = in_memory_desc.padded_dims[i_group_begin];
            md.padded_offsets[o_group_begin]
                    = in_memory_desc.padded_offsets[i_group_begin];
            o_bd.strides[o_group_begin] = i_bd.strides[i_group_begin];
            for (int d = 0; d < i_bd.inner_nblks; ++d)
                if (i_bd.inner_idxs[d] == i_group_begin)
                    o_bd.inner_idxs[d] = o_group_begin;
        } else if (action == action_t::REARRANGE_DIMS) {
            // check that input group is dense, sequential, and is not blocked
            for (int d = i_group_end - 1; d > i_group_begin; --d)
                if (i_dims[d] * i_bd.strides[d] != i_bd.strides[d - 1])
                    return invalid_arguments;

            // checked (i_group_begin, i_group_end), `i_group_begin` remains
            for (int d = 0; d < i_bd.inner_nblks; ++d)
                if (i_bd.inner_idxs[d] == i_group_begin)
                    return invalid_arguments;
            if (in_memory_desc.padded_dims[i_group_begin]
                    != i_dims[i_group_begin])
                return invalid_arguments;
            if (in_memory_desc.padded_offsets[i_group_begin] != 0)
                return invalid_arguments;

            // OK, fill output md according to
            // o_dims[o_group_begin .. o_group_end]

            dim_t current_stride = i_bd.strides[i_group_end - 1];
            for (int d = o_group_end - 1; d >= o_group_begin; --d) {
                md.padded_dims[d] = o_dims[d];
                md.padded_offsets[d] = 0;
                o_bd.strides[d] = current_stride;
                current_stride *= md.padded_dims[d];
            }
        } else {
            assert(action == action_t::FAIL);
            return invalid_arguments;
        }

        i_group_end = i_group_begin;
        o_group_end = o_group_begin;
    }

    out_memory_desc = md;
    return success;
}
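
// Example (illustrative sketch): reshaping a dense {2, 3, 4} descriptor with
// strides {12, 4, 1} to {6, 4} folds the leading {2, 3} group into one axis;
// the new axis inherits the stride of the group's innermost member:
//
//     in:  dims = {2, 3, 4}, strides = {12, 4, 1}
//     out: dims = {6, 4},    strides = {4, 1}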

status_t memory_desc_permute_axes(memory_desc_t &out_memory_desc,
        const memory_desc_t &in_memory_desc, const int *perm) {
    if (!memory_desc_sanity_check(in_memory_desc)
            || !one_of(in_memory_desc.format_kind, format_kind::any,
                    format_kind::blocked)
            || types::is_zero_md(&in_memory_desc)
            || memory_desc_wrapper(in_memory_desc).has_runtime_dims_or_strides()
            || in_memory_desc.extra.flags != 0)
        return invalid_arguments;

    // verify that perm is indeed a permutation of [0 .. ndims)
    unsigned occurrence_mask = 0;
    for (int d = 0; d < in_memory_desc.ndims; ++d)
        if (0 <= perm[d] && perm[d] < in_memory_desc.ndims)
            occurrence_mask |= (1u << perm[d]);
    if (occurrence_mask + 1 != (1u << in_memory_desc.ndims))
        return invalid_arguments;

    out_memory_desc = in_memory_desc;
    for (int d = 0; d < in_memory_desc.ndims; ++d) {
        if (perm[d] == d) continue;
        out_memory_desc.dims[perm[d]] = in_memory_desc.dims[d];
        out_memory_desc.padded_dims[perm[d]] = in_memory_desc.padded_dims[d];
        out_memory_desc.padded_offsets[perm[d]]
                = in_memory_desc.padded_offsets[d];
        if (in_memory_desc.format_kind == format_kind::blocked) {
            const auto &i_bd = in_memory_desc.format_desc.blocking;
            auto &o_bd = out_memory_desc.format_desc.blocking;

            o_bd.strides[perm[d]] = i_bd.strides[d];
            for (int blk = 0; blk < i_bd.inner_nblks; ++blk)
                if (i_bd.inner_idxs[blk] == d) o_bd.inner_idxs[blk] = perm[d];
        }
    }

    return success;
}
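
// Example (illustrative sketch): axis `d` of the input becomes axis `perm[d]`
// of the output, while the underlying data stays put. With perm = {1, 0}:
//
//     in:  dims = {2, 3}, strides = {3, 1}
//     out: dims = {3, 2}, strides = {1, 3}   // a transposed view, no copy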

// This is only used by internal API that is used for testing only.
status_t memory_desc_init_by_string_tag(memory_desc_t &md, int ndims,
        const dims_t dims, data_type_t data_type, const std::string &tag) {
    if (ndims < 0 || ndims > DNNL_MAX_NDIMS) return invalid_arguments;

    // Copy to a temporary to handle the dims == md.dims aliasing case.
    dims_t tmp_dims;
    std::copy(dims, dims + ndims, tmp_dims);

    // Reset the output descriptor: the code below relies on zero-initialized
    // strides and inner block counts.
    md = memory_desc_t();
    md.ndims = ndims;

    std::copy(tmp_dims, tmp_dims + ndims, md.dims);
    md.data_type = data_type;
    md.format_kind = format_kind::blocked;

    // Parse dimensions and their block sizes starting from the innermost one.
    std::vector<std::pair<int, int>> dim_blocks;
    int pos = (int)tag.size() - 1;
    int ndims_from_tag = -1;
    while (pos >= 0) {
        int pos0 = pos;

        --pos;
        while (pos >= 0 && std::isdigit(tag[pos]))
            pos--;

        int dim_idx = std::tolower(tag[pos0]) - 'a';
        if (dim_idx >= ndims) return invalid_arguments;
        ndims_from_tag = std::max(dim_idx + 1, ndims_from_tag);
        int block_str_len = pos0 - pos - 1;
        int block = (block_str_len == 0)
                ? 1
                : std::stoi(tag.substr(pos + 1, block_str_len));
        dim_blocks.emplace_back(dim_idx, block);
    }
    if (ndims_from_tag != ndims) return invalid_arguments;

    auto &blk = md.format_desc.blocking;

    // Compute strides and fill inner block sizes/indices.
    dim_t stride = 1;
    dims_t full_inner_blks;
    std::fill(full_inner_blks, full_inner_blks + ndims, 1);
    for (auto &p : dim_blocks) {
        int dim_idx = p.first;
        int block = p.second;
        if (block == 1) {
            assert(blk.strides[dim_idx] == 0);
            blk.strides[dim_idx] = stride;

            dim_t fib = full_inner_blks[dim_idx];
            dim_t padded_dim = md.dims[dim_idx] == DNNL_RUNTIME_DIM_VAL
                    ? DNNL_RUNTIME_DIM_VAL
                    : (md.dims[dim_idx] + fib - 1) / fib * fib;
            md.padded_dims[dim_idx] = padded_dim;
            if (padded_dim == DNNL_RUNTIME_DIM_VAL)
                stride = DNNL_RUNTIME_DIM_VAL;
            else
                stride *= (padded_dim / fib);
        } else {
            full_inner_blks[dim_idx] *= block;
            blk.inner_blks[blk.inner_nblks] = block;
            blk.inner_idxs[blk.inner_nblks] = dim_idx;
            blk.inner_nblks++;
            stride *= block;
        }
    }

    // Inner block sizes/indices are stored from the outermost to the innermost
    // so need to reverse them.
    std::reverse(blk.inner_blks, blk.inner_blks + blk.inner_nblks);
    std::reverse(blk.inner_idxs, blk.inner_idxs + blk.inner_nblks);

    return success;
}
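
// Example (illustrative sketch): tag "aB8b" with dims = {2, 20} describes a
// layout where dim 1 is split into blocks of 8. Parsing from the right yields
// (b, 8), (B, 1), (a, 1); dim 1 is padded up to a multiple of the block size:
//
//     padded_dims = {2, 24}, strides = {24, 8},
//     inner_nblks = 1, inner_blks = {8}, inner_idxs = {1}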
477
478} // namespace impl
479} // namespace dnnl
480
481// API
482status_t dnnl_memory_desc_create_with_tag(memory_desc_t **memory_desc,
483 int ndims, const dims_t dims, data_type_t data_type, format_tag_t tag) {
484 if (any_null(memory_desc)) return invalid_arguments;
485
486 auto md = utils::make_unique<memory_desc_t>();
487 if (!md) return out_of_memory;
488 CHECK(memory_desc_init_by_tag(*md, ndims, dims, data_type, tag));
489 (*memory_desc) = md.release();
490 return success;
491}
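
// Usage sketch (C API, illustrative; error handling omitted for brevity):
//
//     dnnl_memory_desc_t md;
//     dnnl_dims_t dims = {2, 16, 3, 4};
//     dnnl_memory_desc_create_with_tag(&md, 4, dims, dnnl_f32, dnnl_nchw);
//     ...
//     dnnl_memory_desc_destroy(md);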

status_t dnnl_memory_desc_create_with_strides(memory_desc_t **memory_desc,
        int ndims, const dims_t dims, data_type_t data_type,
        const dims_t strides) {
    if (any_null(memory_desc)) return invalid_arguments;

    auto md = utils::make_unique<memory_desc_t>();
    if (!md) return out_of_memory;
    CHECK(memory_desc_init_by_strides(*md, ndims, dims, data_type, strides));
    (*memory_desc) = md.release();
    return success;
}

status_t dnnl_memory_desc_create_submemory(memory_desc_t **memory_desc,
        const memory_desc_t *parent_memory_desc, const dims_t dims,
        const dims_t offsets) {
    if (any_null(memory_desc, parent_memory_desc)) return invalid_arguments;

    auto md = utils::make_unique<memory_desc_t>();
    if (!md) return out_of_memory;
    CHECK(memory_desc_init_submemory(*md, *parent_memory_desc, dims, offsets));
    (*memory_desc) = md.release();
    return success;
}

status_t dnnl_memory_desc_reshape(memory_desc_t **out_memory_desc,
        const memory_desc_t *in_memory_desc, int ndims, const dims_t dims) {
    if (any_null(out_memory_desc, in_memory_desc)) return invalid_arguments;

    auto md = utils::make_unique<memory_desc_t>();
    if (!md) return out_of_memory;
    CHECK(memory_desc_reshape(*md, *in_memory_desc, ndims, dims));
    (*out_memory_desc) = md.release();
    return success;
}

status_t dnnl_memory_desc_permute_axes(memory_desc_t **out_memory_desc,
        const memory_desc_t *in_memory_desc, const int *perm) {
    if (any_null(out_memory_desc, in_memory_desc)) return invalid_arguments;

    auto md = utils::make_unique<memory_desc_t>();
    if (!md) return out_of_memory;
    CHECK(memory_desc_permute_axes(*md, *in_memory_desc, perm));
    (*out_memory_desc) = md.release();
    return success;
}

int dnnl_memory_desc_equal(const memory_desc_t *lhs, const memory_desc_t *rhs) {
    if (lhs == rhs) return 1;
    if (any_null(lhs, rhs)) return 0;
    return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs);
}

size_t dnnl_memory_desc_get_size(const memory_desc_t *md) {
    if (md == nullptr) return 0;
    return memory_desc_wrapper(*md).size();
}
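
// Example (illustrative): for a dense f32 descriptor with dims {2, 16, 3, 4}
// the size is the padded volume times the data type size:
// 2 * 16 * 3 * 4 * sizeof(float) = 384 * 4 = 1536 bytes.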

size_t dnnl_data_type_size(dnnl_data_type_t data_type) {
    return types::data_type_size(data_type);
}

status_t dnnl_memory_desc_query(
        const memory_desc_t *md, query_t what, void *result) {
    if (any_null(md, result)) return invalid_arguments;

    const bool is_blocked = md->format_kind == format_kind::blocked;

    switch (what) {
        case query::ndims_s32: *(int32_t *)result = md->ndims; break;
        case query::dims: *(const dims_t **)result = &md->dims; break;
        case query::data_type: *(data_type_t *)result = md->data_type; break;
        case query::submemory_offset_s64: *(dim_t *)result = md->offset0; break;
        case query::padded_dims:
            *(const dims_t **)result = &md->padded_dims;
            break;
        case query::padded_offsets:
            *(const dims_t **)result = &md->padded_offsets;
            break;
        case query::format_kind:
            if (one_of(md->format_kind, format_kind::rnn_packed,
                        format_kind::wino)) {
                *(format_kind_t *)result = format_kind::opaque;
                break;
            }
            *(format_kind_t *)result = md->format_kind;
            break;
        case query::strides:
            if (!is_blocked) return status::invalid_arguments;
            *(const dims_t **)result = &md->format_desc.blocking.strides;
            break;
        case query::inner_nblks_s32:
            if (!is_blocked) return status::invalid_arguments;
            *(int32_t *)result = md->format_desc.blocking.inner_nblks;
            break;
        case query::inner_blks:
            if (!is_blocked) return status::invalid_arguments;
            *(const dims_t **)result = &md->format_desc.blocking.inner_blks;
            break;
        case query::inner_idxs:
            if (!is_blocked) return status::invalid_arguments;
            *(const dims_t **)result = &md->format_desc.blocking.inner_idxs;
            break;
        default: return status::unimplemented;
    }
    return status::success;
}
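
// Usage sketch (C API, illustrative): pointer-valued queries return a pointer
// to the descriptor's internal array, scalar queries write into the argument:
//
//     int32_t ndims;
//     dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims);
//     const dnnl_dims_t *strides;
//     dnnl_memory_desc_query(md, dnnl_query_strides, &strides);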

status_t dnnl_memory_desc_destroy(memory_desc_t *memory_desc) {
    delete memory_desc;
    return success;
}

status_t dnnl_memory_desc_clone(memory_desc_t **memory_desc,
        const memory_desc_t *existing_memory_desc) {
    if (any_null(memory_desc, existing_memory_desc)) return invalid_arguments;

    (*memory_desc) = new memory_desc_t(*existing_memory_desc);
    return success;
}

// This is an internal API that is used only for testing in benchdnn.
extern "C" status_t DNNL_API dnnl_memory_desc_create_with_string_tag(
        memory_desc_t **memory_desc, int ndims, const dims_t dims,
        data_type_t data_type, const char *tag) {
    if (any_null(memory_desc)) return invalid_arguments;

    auto md = utils::make_unique<memory_desc_t>();
    if (!md) return out_of_memory;
    CHECK(memory_desc_init_by_string_tag(*md, ndims, dims, data_type, tag));
    (*memory_desc) = md.release();
    return success;
}

extern "C" status_t DNNL_API dnnl_memory_desc_set_data_type(
        memory_desc_t *memory_desc, data_type_t data_type) {
    if (any_null(memory_desc)) return invalid_arguments;
    memory_desc->data_type = data_type;
    return success;
}