1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | |
19 | #include "oneapi/dnnl/dnnl.h" |
20 | |
21 | #include "c_types_map.hpp" |
22 | #include "concat_pd.hpp" |
23 | #include "engine.hpp" |
24 | #include "impl_list_item.hpp" |
25 | #include "primitive_cache.hpp" |
26 | #include "primitive_desc_iface.hpp" |
27 | #include "primitive_hashing.hpp" |
28 | #include "type_helpers.hpp" |
29 | #include "utils.hpp" |
30 | |
31 | using namespace dnnl::impl; |
32 | using namespace dnnl::impl::utils; |
33 | using namespace dnnl::impl::status; |
34 | |
35 | namespace dnnl { |
36 | namespace impl { |
37 | |
38 | status_t concat_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd, |
39 | engine_t *engine, const memory_desc_t *dst_md, int n, int concat_dim, |
40 | const memory_desc_t *const *src_mds, const primitive_attr_t *attr) { |
41 | |
42 | bool args_ok = !any_null(src_mds) && n > 0; |
43 | if (!args_ok) return invalid_arguments; |
44 | |
45 | if (attr == nullptr) attr = &default_attr(); |
46 | |
47 | const int ndims = src_mds[0]->ndims; |
48 | const dims_t &dims = src_mds[0]->dims; |
49 | const data_type_t dt = src_mds[0]->data_type; |
50 | if (memory_desc_wrapper(src_mds[0]).has_runtime_dims_or_strides()) |
51 | return unimplemented; |
52 | |
53 | int concat_dim_sz = dims[concat_dim]; |
54 | if (memory_desc_wrapper(src_mds[0]).format_any()) return invalid_arguments; |
55 | for (int i = 1; i < n; ++i) { |
56 | const memory_desc_t &src_md = *src_mds[i]; |
57 | if (src_md.ndims != ndims || memory_desc_wrapper(src_md).format_any()) |
58 | return invalid_arguments; |
59 | if (memory_desc_wrapper(src_md).has_runtime_dims_or_strides()) |
60 | return unimplemented; |
61 | |
62 | for (int d = 0; d < ndims; ++d) { |
63 | if (d == concat_dim) continue; |
64 | if (src_md.dims[d] != dims[d]) return invalid_arguments; |
65 | } |
66 | if (src_md.data_type != dt) return invalid_arguments; |
67 | concat_dim_sz += src_md.dims[concat_dim]; |
68 | } |
69 | |
70 | memory_desc_t dummy_dst_md; |
71 | if (dst_md) { |
72 | if (dst_md->ndims != ndims) return invalid_arguments; |
73 | if (memory_desc_wrapper(dst_md).has_runtime_dims_or_strides()) |
74 | return unimplemented; |
75 | for (int d = 0; d < ndims; ++d) { |
76 | if (dst_md->dims[d] != (d == concat_dim ? concat_dim_sz : dims[d])) |
77 | return invalid_arguments; |
78 | } |
79 | } else { |
80 | dummy_dst_md = *src_mds[0]; |
81 | dummy_dst_md.dims[concat_dim] = concat_dim_sz; |
82 | dummy_dst_md.format_kind = format_kind::any; |
83 | dst_md = &dummy_dst_md; |
84 | } |
85 | |
86 | auto desc = concat_desc_t( |
87 | primitive_kind::concat, dst_md, n, concat_dim, src_mds); |
88 | primitive_hashing::key_t key( |
89 | engine, reinterpret_cast<op_desc_t *>(&desc), attr, 0, {}); |
90 | pd = primitive_cache().get_pd(key); |
91 | |
92 | if (pd) return success; |
93 | |
94 | concat_pd_t *concat_pd = nullptr; |
95 | for (auto c = engine->get_concat_implementation_list(); *c; ++c) { |
96 | if ((*c)(&concat_pd, engine, attr, dst_md, n, concat_dim, src_mds) |
97 | == success) { |
98 | pd.reset(concat_pd); |
99 | return success; |
100 | } |
101 | } |
102 | return unimplemented; |
103 | } |
104 | |
105 | status_t concat_primitive_desc_create(std::shared_ptr<primitive_desc_t> &pd, |
106 | engine_t *engine, const memory_desc_t *dst_md, int n, int concat_dim, |
107 | const memory_desc_t *src_mds, const primitive_attr_t *attr) { |
108 | std::vector<const memory_desc_t *> src_mds_ptrs(n); |
109 | for (int i = 0; i < n; i++) |
110 | src_mds_ptrs[i] = &src_mds[i]; |
111 | return concat_primitive_desc_create( |
112 | pd, engine, dst_md, n, concat_dim, src_mds_ptrs.data(), attr); |
113 | } |
114 | |
115 | } // namespace impl |
116 | } // namespace dnnl |
117 | |
118 | status_t dnnl_concat_primitive_desc_create( |
119 | primitive_desc_iface_t **concat_pd_iface, engine_t *engine, |
120 | const memory_desc_t *dst_md, int n, int concat_dim, |
121 | const memory_desc_t *const *src_mds, const primitive_attr_t *attr) { |
122 | if (any_null(concat_pd_iface)) return invalid_arguments; |
123 | |
124 | std::shared_ptr<primitive_desc_t> pd; |
125 | CHECK(concat_primitive_desc_create( |
126 | pd, engine, dst_md, n, concat_dim, src_mds, attr)); |
127 | return safe_ptr_assign( |
128 | *concat_pd_iface, new primitive_desc_iface_t(pd, engine)); |
129 | } |
130 | |