1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "oneapi/dnnl/dnnl.h"
18
19#include "c_types_map.hpp"
20
21#include "engine.hpp"
22#include "primitive_desc_iface.hpp"
23#include "primitive_desc_iterator.hpp"
24#include "primitive_iface.hpp"
25
26using namespace dnnl::impl;
27using namespace dnnl::impl::status;
28
29namespace dnnl {
30namespace impl {
31
32status_t primitive_desc_create(primitive_desc_iface_t **primitive_desc_iface,
33 engine_t *engine, const op_desc_t *op_desc,
34 const primitive_desc_iface_t *hint_fwd_pd,
35 const primitive_attr_t *attr) {
36 using namespace primitive_kind;
37
38 if (!primitive_desc_iface) return invalid_arguments;
39
40 const bool known_primitive_kind = utils::one_of(op_desc->kind,
41 batch_normalization, binary, convolution, deconvolution, eltwise,
42 gemm, inner_product, layer_normalization, lrn, matmul, pooling,
43 prelu, reduction, resampling, rnn, shuffle, softmax);
44 if (!known_primitive_kind) return invalid_arguments;
45
46 auto pd_iface = utils::make_unique<primitive_desc_iface_t>(engine, op_desc,
47 attr, hint_fwd_pd ? hint_fwd_pd->impl().get() : nullptr);
48 if (pd_iface == nullptr) return out_of_memory;
49 CHECK(pd_iface->init());
50
51 *primitive_desc_iface = pd_iface.release();
52
53 return success;
54}
55
56} // namespace impl
57} // namespace dnnl
58
59dnnl_primitive_desc::dnnl_primitive_desc(
60 const std::shared_ptr<primitive_desc_t> &pd, engine_t *engine)
61 : pd_(pd), engine_(engine) {}
62
63dnnl_primitive_desc::dnnl_primitive_desc(engine_t *engine,
64 const op_desc_t *op_desc, const primitive_attr_t *attr,
65 const primitive_desc_t *hint_fwd_pd) {
66
67 pd_iterator_ = utils::make_unique<primitive_desc_iterator_t>(
68 engine, op_desc, attr, hint_fwd_pd);
69}
70
71status_t dnnl_primitive_desc::init() {
72 if (!pd_iterator_) return status::out_of_memory;
73 if (!pd_iterator_->is_initialized()) return out_of_memory;
74
75 ++(*pd_iterator_);
76 if (*pd_iterator_ == pd_iterator_->end()) return unimplemented;
77
78 pd_ = *(*pd_iterator_);
79 engine_ = pd_iterator_->engine();
80
81 return success;
82}
83
84status_t dnnl_primitive_desc::next_impl() {
85 if (!pd_iterator_) return status::last_impl_reached;
86 ++(*pd_iterator_);
87 if (*pd_iterator_ == pd_iterator_->end()) return last_impl_reached;
88 pd_ = *(*pd_iterator_);
89 return status::success;
90}
91
92status_t dnnl_primitive_desc::create_primitive_iface(
93 std::pair<primitive_iface_t *, bool> &primitive_iface,
94 const cache_blob_t &cache_blob) const {
95 // Step 1: create impl::primitive_t or get it from primitive cache
96 std::pair<std::shared_ptr<primitive_t>, bool> p;
97 auto status = impl()->create_primitive(p, engine(), cache_blob);
98 if (status != status::success) return status;
99 // Step 2: create primitive_iface_t, init and return it to user
100 primitive_iface_t *p_iface = nullptr;
101 CHECK(safe_ptr_assign(p_iface, new primitive_iface_t(p.first, engine())));
102 status = p_iface->init();
103 if (status != status::success) {
104 p_iface->release();
105 return status;
106 }
107 primitive_iface = std::make_pair(p_iface, p.second);
108 return status::success;
109}
110
111const std::shared_ptr<primitive_desc_t> &dnnl_primitive_desc::impl() const {
112 return pd_;
113}
114
115dnnl::impl::engine_t *dnnl_primitive_desc::engine() const {
116 return engine_;
117}
118const dnnl::impl::primitive_attr_t *dnnl_primitive_desc::attr() const {
119 return impl()->attr();
120}
121
122const char *dnnl_primitive_desc::info() const {
123 return impl()->info(engine_);
124}
125
126dnnl::impl::engine_t *dnnl_primitive_desc::src_engine() const {
127 return engine();
128}
129dnnl::impl::engine_t *dnnl_primitive_desc::dst_engine() const {
130 return engine();
131}
132
133dnnl::impl::engine_t *dnnl_primitive_desc::scratchpad_engine() const {
134 return engine();
135}
136
137status_t dnnl_primitive_desc::query(query_t what, int idx, void *result) const {
138 auto status = status::success;
139 switch (what) {
140 case query::engine: *(engine_t **)result = engine(); break;
141 case query::cache_blob_id_size_s64:
142 *(dim_t *)result
143 = (dim_t)impl()->get_cache_blob_id(engine()).size();
144 break;
145 case query::cache_blob_id:
146 *(const uint8_t **)result
147 = impl()->get_cache_blob_id(engine()).empty()
148 ? nullptr
149 : impl()->get_cache_blob_id(engine()).data();
150 break;
151
152 default: status = impl()->query(what, idx, result);
153 }
154 return status;
155}
156
157status_t dnnl_primitive_desc_get_attr(
158 const primitive_desc_iface_t *primitive_desc_iface,
159 const primitive_attr_t **attr) {
160 if (utils::any_null(primitive_desc_iface, attr)) return invalid_arguments;
161
162 *attr = primitive_desc_iface->attr();
163 return success;
164}
165
166status_t dnnl_primitive_desc_clone(
167 primitive_desc_iface_t **primitive_desc_iface,
168 const primitive_desc_iface_t *existing_primitive_desc_iface) {
169 if (utils::any_null(primitive_desc_iface, existing_primitive_desc_iface))
170 return invalid_arguments;
171
172 return safe_ptr_assign(*primitive_desc_iface,
173 new primitive_desc_iface_t(existing_primitive_desc_iface->impl(),
174 existing_primitive_desc_iface->engine()));
175}
176
177status_t dnnl_primitive_desc_destroy(
178 primitive_desc_iface_t *primitive_desc_iface) {
179 delete primitive_desc_iface;
180 return success;
181}
182
183status_t dnnl_primitive_desc_next_impl(
184 primitive_desc_iface_t *primitive_desc_iface) {
185 if (!primitive_desc_iface) return invalid_arguments;
186 return primitive_desc_iface->next_impl();
187}
188