1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18#include "oneapi/dnnl/dnnl.h"
19
20#include "c_types_map.hpp"
21#include "opdesc.hpp"
22#include "primitive_desc_iface.hpp"
23#include "type_helpers.hpp"
24#include "utils.hpp"
25
26using namespace dnnl::impl;
27using namespace dnnl::impl::utils;
28using namespace dnnl::impl::status;
29using namespace dnnl::impl::prop_kind;
30using namespace dnnl::impl::types;
31
32namespace {
33status_t shuffle_desc_init(shuffle_desc_t *shuffle_desc, prop_kind_t prop_kind,
34 const memory_desc_t *src_desc, const memory_desc_t *dst_desc, int axis,
35 dim_t group_size) {
36 bool args_ok = !any_null(shuffle_desc, src_desc, dst_desc)
37 && one_of(prop_kind, forward_training, forward_inference,
38 backward_data)
39 && IMPLICATION(prop_kind != backward_data,
40 src_desc->format_kind != format_kind::any)
41 && axis >= 0 && axis < src_desc->ndims && group_size > 0
42 && group_size <= src_desc->dims[axis];
43 if (!args_ok) return invalid_arguments;
44
45 if (memory_desc_wrapper(src_desc).has_runtime_dims_or_strides()
46 || memory_desc_wrapper(dst_desc).has_runtime_dims_or_strides())
47 return unimplemented;
48
49 auto sd = shuffle_desc_t();
50 sd.primitive_kind = primitive_kind::shuffle;
51 sd.prop_kind = prop_kind;
52 sd.src_desc = *src_desc;
53 sd.dst_desc = *dst_desc;
54 sd.axis = axis;
55 sd.group_size = group_size;
56
57 bool consistency = sd.src_desc.dims[axis] % sd.group_size == 0
58 && sd.dst_desc.ndims == sd.src_desc.ndims
59 && array_cmp(sd.dst_desc.dims, sd.src_desc.dims, sd.src_desc.ndims);
60 if (!consistency) return invalid_arguments;
61
62 *shuffle_desc = sd;
63 return success;
64}
65} // namespace
66
67dnnl_status_t dnnl_shuffle_forward_primitive_desc_create(
68 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
69 prop_kind_t prop_kind, const memory_desc_t *src_desc,
70 const memory_desc_t *dst_desc, int axis, dim_t group_size,
71 const primitive_attr_t *attr) {
72 if (!one_of(prop_kind, forward_training, forward_inference))
73 return invalid_arguments;
74
75 auto shuffle_desc = shuffle_desc_t();
76 CHECK(shuffle_desc_init(
77 &shuffle_desc, prop_kind, src_desc, dst_desc, axis, group_size));
78 return primitive_desc_create(primitive_desc_iface, engine,
79 (const op_desc_t *)&shuffle_desc, nullptr, attr);
80}
81
82dnnl_status_t dnnl_shuffle_backward_primitive_desc_create(
83 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
84 const memory_desc_t *diff_src_desc, const memory_desc_t *diff_dst_desc,
85 int axis, dim_t group_size, const primitive_desc_iface_t *hint_fwd_pd,
86 const primitive_attr_t *attr) {
87
88 auto shuffle_desc = shuffle_desc_t();
89 CHECK(shuffle_desc_init(&shuffle_desc, backward_data, diff_src_desc,
90 diff_dst_desc, axis, group_size));
91 return primitive_desc_create(primitive_desc_iface, engine,
92 (const op_desc_t *)&shuffle_desc, hint_fwd_pd, attr);
93}
94
// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
96