1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <assert.h>
18
19#include "oneapi/dnnl/dnnl.h"
20
21#include "c_types_map.hpp"
22#include "engine.hpp"
23#include "impl_list_item.hpp"
24#include "primitive_cache.hpp"
25#include "primitive_desc_iface.hpp"
26#include "primitive_hashing.hpp"
27#include "type_helpers.hpp"
28#include "utils.hpp"
29
30#include "sum_pd.hpp"
31
32using namespace dnnl::impl;
33using namespace dnnl::impl::utils;
34using namespace dnnl::impl::status;
35
36namespace dnnl {
37namespace impl {
38
39status_t sum_primitive_desc_create(primitive_desc_iface_t **sum_pd_iface,
40 const memory_desc_t *dst_md, int n, const float *scales,
41 const memory_desc_t *const *src_mds, const primitive_attr_t *attr,
42 engine_t *engine) {
43
44 bool args_ok = !any_null(sum_pd_iface, src_mds, scales) && n > 0;
45 if (!args_ok) return invalid_arguments;
46
47 if (attr == nullptr) attr = &default_attr();
48
49 const int ndims = src_mds[0]->ndims;
50 const dims_t &dims = src_mds[0]->dims;
51 if (memory_desc_wrapper(src_mds[0]).has_runtime_dims_or_strides())
52 return unimplemented;
53
54 if (memory_desc_wrapper(src_mds[0]).format_any()) return invalid_arguments;
55 for (int i = 1; i < n; ++i) {
56 const memory_desc_t &src_md = *src_mds[i];
57 if (src_md.ndims != ndims || memory_desc_wrapper(src_md).format_any())
58 return invalid_arguments;
59 if (memory_desc_wrapper(src_md).has_runtime_dims_or_strides())
60 return unimplemented;
61 for (int d = 0; d < ndims; ++d) {
62 if (src_md.dims[d] != dims[d]) return invalid_arguments;
63 }
64 }
65
66 memory_desc_t dummy_dst_md;
67 if (dst_md) {
68 if (dst_md->ndims != ndims) return invalid_arguments;
69 if (memory_desc_wrapper(dst_md).has_runtime_dims_or_strides())
70 return unimplemented;
71 for (int d = 0; d < ndims; ++d) {
72 if (dst_md->dims[d] != dims[d]) return invalid_arguments;
73 }
74 } else {
75 dummy_dst_md = *src_mds[0];
76 dummy_dst_md.format_kind = format_kind::any;
77 dst_md = &dummy_dst_md;
78 }
79
80 auto desc = sum_desc_t(primitive_kind::sum, dst_md, n, scales, src_mds);
81 primitive_hashing::key_t key(
82 engine, reinterpret_cast<op_desc_t *>(&desc), attr, 0, {});
83 auto pd = primitive_cache().get_pd(key);
84
85 if (pd) {
86 return safe_ptr_assign(
87 *sum_pd_iface, new primitive_desc_iface_t(pd, engine));
88 }
89
90 for (auto s = engine->get_sum_implementation_list(); *s; ++s) {
91 sum_pd_t *sum_pd = nullptr;
92 if ((*s)(&sum_pd, engine, attr, dst_md, n, scales, src_mds)
93 == success) {
94 pd.reset(sum_pd);
95 CHECK(safe_ptr_assign(
96 *sum_pd_iface, new primitive_desc_iface_t(pd, engine)));
97 return status::success;
98 }
99 }
100 return unimplemented;
101}
102
103} // namespace impl
104} // namespace dnnl
105
106status_t dnnl_sum_primitive_desc_create(primitive_desc_iface_t **sum_pd_iface,
107 engine_t *engine, const memory_desc_t *dst_md, int n,
108 const float *scales, const memory_desc_t *const *src_mds,
109 const primitive_attr_t *attr) {
110 return sum_primitive_desc_create(
111 sum_pd_iface, dst_md, n, scales, src_mds, attr, engine);
112}
113