1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_PRIMITIVE_DESC_IFACE_HPP
18#define COMMON_PRIMITIVE_DESC_IFACE_HPP
19
#include <memory>
#include <utility>

#include "oneapi/dnnl/dnnl.h"

#include "c_types_map.hpp"
#include "cache_blob.hpp"
#include "primitive_desc_iterator.hpp"
25
26namespace dnnl {
27namespace impl {
28
29status_t primitive_desc_create(primitive_desc_iface_t **primitive_desc_iface,
30 engine_t *engine, const op_desc_t *op_desc,
31 const primitive_desc_iface_t *hint_fwd_pd,
32 const primitive_attr_t *attr);
33}
34} // namespace dnnl
35
// dnnl_primitive_desc is a user-facing entity; internally it is referred to
// through the alias primitive_desc_iface_t.
// A primitive_desc_iface_t is responsible for holding:
// 1. impl::primitive_desc_t - a primitive descriptor implementation that
//    can be stored in the primitive cache as part of the primitive
//    implementation to which it belongs.
// 2. engine_t - the oneDNN engine associated with the primitive descriptor.
43struct dnnl_primitive_desc : public dnnl::impl::c_compatible {
44 dnnl_primitive_desc(const std::shared_ptr<dnnl::impl::primitive_desc_t> &pd,
45 dnnl::impl::engine_t *engine);
46
47 dnnl_primitive_desc(dnnl::impl::engine_t *engine,
48 const dnnl::impl::op_desc_t *op_desc,
49 const dnnl::impl::primitive_attr_t *attr,
50 const dnnl::impl::primitive_desc_t *hint_fwd_pd);
51
52 virtual ~dnnl_primitive_desc() = default;
53
54 dnnl::impl::status_t init();
55 dnnl::impl::status_t next_impl();
56 const char *info() const;
57 dnnl::impl::engine_t *engine() const;
58 const dnnl::impl::primitive_attr_t *attr() const;
59 virtual dnnl::impl::engine_t *scratchpad_engine() const;
60
61 virtual dnnl::impl::engine_t *src_engine() const;
62 virtual dnnl::impl::engine_t *dst_engine() const;
63
64 virtual dnnl::impl::status_t query(
65 dnnl::impl::query_t what, int idx, void *result) const;
66
67 virtual dnnl::impl::status_t create_primitive_iface(
68 std::pair<primitive_iface_t *, bool> &primitive_iface,
69 const dnnl::impl::cache_blob_t &cache_blob) const;
70
71 const std::shared_ptr<dnnl::impl::primitive_desc_t> &impl() const;
72
73protected:
74 std::unique_ptr<dnnl::impl::primitive_desc_iterator_t> pd_iterator_;
75 // TODO: Extend iterator to support concat, sum and reorder primitives.
76 // Until it's done we need to have primitive descriptor (`pd_`) and
77 // engine (engine_) here.
78 std::shared_ptr<dnnl::impl::primitive_desc_t> pd_;
79 dnnl::impl::engine_t *engine_;
80};
81
82#endif
83