1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18
19#include "oneapi/dnnl/dnnl.h"
20
21#include "c_types_map.hpp"
22#include "concat_pd.hpp"
23#include "engine.hpp"
24#include "impl_list_item.hpp"
25#include "primitive_cache.hpp"
26#include "primitive_desc_iface.hpp"
27#include "primitive_hashing.hpp"
28#include "type_helpers.hpp"
29#include "utils.hpp"
30
31using namespace dnnl::impl;
32using namespace dnnl::impl::utils;
33using namespace dnnl::impl::status;
34
35namespace dnnl {
36namespace impl {
37
38status_t concat_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
39 engine_t *engine, const memory_desc_t *dst_md, int n, int concat_dim,
40 const memory_desc_t *const *src_mds, const primitive_attr_t *attr) {
41
42 bool args_ok = !any_null(src_mds) && n > 0;
43 if (!args_ok) return invalid_arguments;
44
45 if (attr == nullptr) attr = &default_attr();
46
47 const int ndims = src_mds[0]->ndims;
48 const dims_t &dims = src_mds[0]->dims;
49 const data_type_t dt = src_mds[0]->data_type;
50 if (memory_desc_wrapper(src_mds[0]).has_runtime_dims_or_strides())
51 return unimplemented;
52
53 int concat_dim_sz = dims[concat_dim];
54 if (memory_desc_wrapper(src_mds[0]).format_any()) return invalid_arguments;
55 for (int i = 1; i < n; ++i) {
56 const memory_desc_t &src_md = *src_mds[i];
57 if (src_md.ndims != ndims || memory_desc_wrapper(src_md).format_any())
58 return invalid_arguments;
59 if (memory_desc_wrapper(src_md).has_runtime_dims_or_strides())
60 return unimplemented;
61
62 for (int d = 0; d < ndims; ++d) {
63 if (d == concat_dim) continue;
64 if (src_md.dims[d] != dims[d]) return invalid_arguments;
65 }
66 if (src_md.data_type != dt) return invalid_arguments;
67 concat_dim_sz += src_md.dims[concat_dim];
68 }
69
70 memory_desc_t dummy_dst_md;
71 if (dst_md) {
72 if (dst_md->ndims != ndims) return invalid_arguments;
73 if (memory_desc_wrapper(dst_md).has_runtime_dims_or_strides())
74 return unimplemented;
75 for (int d = 0; d < ndims; ++d) {
76 if (dst_md->dims[d] != (d == concat_dim ? concat_dim_sz : dims[d]))
77 return invalid_arguments;
78 }
79 } else {
80 dummy_dst_md = *src_mds[0];
81 dummy_dst_md.dims[concat_dim] = concat_dim_sz;
82 dummy_dst_md.format_kind = format_kind::any;
83 dst_md = &dummy_dst_md;
84 }
85
86 auto desc = concat_desc_t(
87 primitive_kind::concat, dst_md, n, concat_dim, src_mds);
88 primitive_hashing::key_t key(
89 engine, reinterpret_cast<op_desc_t *>(&desc), attr, 0, {});
90 pd = primitive_cache().get_pd(key);
91
92 if (pd) return success;
93
94 concat_pd_t *concat_pd = nullptr;
95 for (auto c = engine->get_concat_implementation_list(); *c; ++c) {
96 if ((*c)(&concat_pd, engine, attr, dst_md, n, concat_dim, src_mds)
97 == success) {
98 pd.reset(concat_pd);
99 return success;
100 }
101 }
102 return unimplemented;
103}
104
105status_t concat_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd,
106 engine_t *engine, const memory_desc_t *dst_md, int n, int concat_dim,
107 const memory_desc_t *src_mds, const primitive_attr_t *attr) {
108 std::vector<const memory_desc_t *> src_mds_ptrs(n);
109 for (int i = 0; i < n; i++)
110 src_mds_ptrs[i] = &src_mds[i];
111 return concat_primitive_desc_create(
112 pd, engine, dst_md, n, concat_dim, src_mds_ptrs.data(), attr);
113}
114
115} // namespace impl
116} // namespace dnnl
117
118status_t dnnl_concat_primitive_desc_create(
119 primitive_desc_iface_t **concat_pd_iface, engine_t *engine,
120 const memory_desc_t *dst_md, int n, int concat_dim,
121 const memory_desc_t *const *src_mds, const primitive_attr_t *attr) {
122 if (any_null(concat_pd_iface)) return invalid_arguments;
123
124 std::shared_ptr<primitive_desc_t> pd;
125 CHECK(concat_primitive_desc_create(
126 pd, engine, dst_md, n, concat_dim, src_mds, attr));
127 return safe_ptr_assign(
128 *concat_pd_iface, new primitive_desc_iface_t(pd, engine));
129}
130