1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <assert.h> |
18 | |
19 | #include "oneapi/dnnl/dnnl.h" |
20 | |
21 | #include "c_types_map.hpp" |
22 | #include "engine.hpp" |
23 | #include "impl_list_item.hpp" |
24 | #include "primitive_cache.hpp" |
25 | #include "primitive_desc_iface.hpp" |
26 | #include "primitive_hashing.hpp" |
27 | #include "type_helpers.hpp" |
28 | #include "utils.hpp" |
29 | |
30 | #include "sum_pd.hpp" |
31 | |
32 | using namespace dnnl::impl; |
33 | using namespace dnnl::impl::utils; |
34 | using namespace dnnl::impl::status; |
35 | |
36 | namespace dnnl { |
37 | namespace impl { |
38 | |
39 | status_t sum_primitive_desc_create(primitive_desc_iface_t **sum_pd_iface, |
40 | const memory_desc_t *dst_md, int n, const float *scales, |
41 | const memory_desc_t *const *src_mds, const primitive_attr_t *attr, |
42 | engine_t *engine) { |
43 | |
44 | bool args_ok = !any_null(sum_pd_iface, src_mds, scales) && n > 0; |
45 | if (!args_ok) return invalid_arguments; |
46 | |
47 | if (attr == nullptr) attr = &default_attr(); |
48 | |
49 | const int ndims = src_mds[0]->ndims; |
50 | const dims_t &dims = src_mds[0]->dims; |
51 | if (memory_desc_wrapper(src_mds[0]).has_runtime_dims_or_strides()) |
52 | return unimplemented; |
53 | |
54 | if (memory_desc_wrapper(src_mds[0]).format_any()) return invalid_arguments; |
55 | for (int i = 1; i < n; ++i) { |
56 | const memory_desc_t &src_md = *src_mds[i]; |
57 | if (src_md.ndims != ndims || memory_desc_wrapper(src_md).format_any()) |
58 | return invalid_arguments; |
59 | if (memory_desc_wrapper(src_md).has_runtime_dims_or_strides()) |
60 | return unimplemented; |
61 | for (int d = 0; d < ndims; ++d) { |
62 | if (src_md.dims[d] != dims[d]) return invalid_arguments; |
63 | } |
64 | } |
65 | |
66 | memory_desc_t dummy_dst_md; |
67 | if (dst_md) { |
68 | if (dst_md->ndims != ndims) return invalid_arguments; |
69 | if (memory_desc_wrapper(dst_md).has_runtime_dims_or_strides()) |
70 | return unimplemented; |
71 | for (int d = 0; d < ndims; ++d) { |
72 | if (dst_md->dims[d] != dims[d]) return invalid_arguments; |
73 | } |
74 | } else { |
75 | dummy_dst_md = *src_mds[0]; |
76 | dummy_dst_md.format_kind = format_kind::any; |
77 | dst_md = &dummy_dst_md; |
78 | } |
79 | |
80 | auto desc = sum_desc_t(primitive_kind::sum, dst_md, n, scales, src_mds); |
81 | primitive_hashing::key_t key( |
82 | engine, reinterpret_cast<op_desc_t *>(&desc), attr, 0, {}); |
83 | auto pd = primitive_cache().get_pd(key); |
84 | |
85 | if (pd) { |
86 | return safe_ptr_assign( |
87 | *sum_pd_iface, new primitive_desc_iface_t(pd, engine)); |
88 | } |
89 | |
90 | for (auto s = engine->get_sum_implementation_list(); *s; ++s) { |
91 | sum_pd_t *sum_pd = nullptr; |
92 | if ((*s)(&sum_pd, engine, attr, dst_md, n, scales, src_mds) |
93 | == success) { |
94 | pd.reset(sum_pd); |
95 | CHECK(safe_ptr_assign( |
96 | *sum_pd_iface, new primitive_desc_iface_t(pd, engine))); |
97 | return status::success; |
98 | } |
99 | } |
100 | return unimplemented; |
101 | } |
102 | |
103 | } // namespace impl |
104 | } // namespace dnnl |
105 | |
106 | status_t dnnl_sum_primitive_desc_create(primitive_desc_iface_t **sum_pd_iface, |
107 | engine_t *engine, const memory_desc_t *dst_md, int n, |
108 | const float *scales, const memory_desc_t *const *src_mds, |
109 | const primitive_attr_t *attr) { |
110 | return sum_primitive_desc_create( |
111 | sum_pd_iface, dst_md, n, scales, src_mds, attr, engine); |
112 | } |
113 | |