1/*******************************************************************************
2* Copyright 2018-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_PRIMITIVE_ITERATOR_HPP
18#define COMMON_PRIMITIVE_ITERATOR_HPP
19
20#include "oneapi/dnnl/dnnl.h"
21
22#include "c_types_map.hpp"
23#include "engine.hpp"
24#include "impl_list_item.hpp"
25#include "primitive_attr.hpp"
26#include "primitive_cache.hpp"
27#include "primitive_desc.hpp"
28#include "primitive_hashing.hpp"
29#include "type_helpers.hpp"
30
31struct dnnl_primitive_desc_iterator : public dnnl::impl::c_compatible {
32 dnnl_primitive_desc_iterator(dnnl::impl::engine_t *engine,
33 const dnnl::impl::op_desc_t *op_desc,
34 const dnnl::impl::primitive_attr_t *attr,
35 const dnnl::impl::primitive_desc_t *hint_fwd_pd, int skip_idx = -1)
36 : idx_(-1)
37 , engine_(engine)
38 , op_desc_(op_desc)
39 , attr_(attr ? *attr : dnnl::impl::primitive_attr_t())
40 , hint_fwd_pd_(hint_fwd_pd)
41 , impl_list_(engine_->get_implementation_list(op_desc_))
42 , last_idx_(0)
43 , skip_idx_(skip_idx)
44 , offset_(-1) {
45 while (impl_list_[last_idx_])
46 ++last_idx_;
47 is_initialized_ = is_initialized_ && attr_.is_initialized();
48 }
49
50 dnnl::impl::engine_t *engine() const { return engine_; }
51
52 bool operator==(const dnnl::impl::primitive_desc_iterator_t &rhs) const {
53 return idx_ == rhs.idx_ && engine_ == rhs.engine_;
54 }
55 bool operator!=(const dnnl::impl::primitive_desc_iterator_t &rhs) const {
56 return !operator==(rhs);
57 }
58
59 dnnl::impl::primitive_desc_iterator_t end() const {
60 return dnnl_primitive_desc_iterator(engine_, last_idx_);
61 }
62
63 dnnl::impl::primitive_desc_iterator_t &operator++() {
64 // Quick return to preserve state of the iterator that reached the end.
65 // The state is equal to the state of the iterator that end() returns.
66 if (idx_ == last_idx_) return *this;
67
68 offset_++;
69 pd_.reset();
70
71 std::vector<dnnl::impl::memory_desc_t> hint_mds;
72 if (hint_fwd_pd_) hint_mds = hint_fwd_pd_->hint_mds(true /* is_hint */);
73 dnnl::impl::primitive_hashing::key_t key(
74 engine_, op_desc_, &attr_, offset_, hint_mds);
75
76 pd_ = dnnl::impl::primitive_cache().get_pd(key);
77 if (pd_) { return *this; }
78
79 while (++idx_ != last_idx_) {
80 if (idx_ == skip_idx_) continue;
81 dnnl::impl::primitive_desc_t *candidate_pd = nullptr;
82 auto s = impl_list_[idx_](&candidate_pd, op_desc_, &attr_, engine_,
83 hint_fwd_pd_, offset_);
84 if (s == dnnl::impl::status::success) {
85 pd_.reset(candidate_pd);
86 break;
87 }
88 }
89 return *this;
90 }
91
92 std::shared_ptr<dnnl::impl::primitive_desc_t> operator*() const {
93 if (*this == end() || pd_ == nullptr) return nullptr;
94 return pd_;
95 }
96
97 const dnnl::impl::primitive_attr_t &attr() const { return attr_; }
98
99 bool is_initialized() const { return is_initialized_; }
100
101protected:
102 int idx_;
103 dnnl::impl::engine_t *engine_;
104 std::shared_ptr<dnnl::impl::primitive_desc_t> pd_;
105 const dnnl::impl::op_desc_t *op_desc_;
106 const dnnl::impl::primitive_attr_t attr_;
107 const dnnl::impl::primitive_desc_t *hint_fwd_pd_;
108 const dnnl::impl::impl_list_item_t *impl_list_;
109 int last_idx_;
110 int skip_idx_;
111 int offset_;
112
113private:
114 dnnl_primitive_desc_iterator(dnnl::impl::engine_t *engine, int last_idx)
115 : idx_(last_idx)
116 , engine_(engine)
117 , op_desc_(nullptr)
118 , hint_fwd_pd_(nullptr)
119 , impl_list_(nullptr)
120 , last_idx_(last_idx)
121 , skip_idx_(-1)
122 , offset_(-1) {}
123
124 dnnl_primitive_desc_iterator(dnnl_primitive_desc_iterator &&other)
125 : idx_(other.idx_)
126 , engine_(other.engine_)
127 , pd_(std::move(other.pd_))
128 , op_desc_(other.op_desc_)
129 , attr_(other.attr_)
130 , hint_fwd_pd_(other.hint_fwd_pd_)
131 , impl_list_(other.impl_list_)
132 , skip_idx_(other.skip_idx_)
133 , offset_(other.offset_) {}
134
135 DNNL_DISALLOW_COPY_AND_ASSIGN(dnnl_primitive_desc_iterator);
136};
137
138#endif
139