1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_PRIMITIVE_ITERATOR_HPP
18#define COMMON_PRIMITIVE_ITERATOR_HPP
19
20#include "oneapi/dnnl/dnnl.h"
21
22#include "c_types_map.hpp"
23#include "engine.hpp"
24#include "impl_list_item.hpp"
25#include "primitive_attr.hpp"
26#include "primitive_cache.hpp"
27#include "primitive_hashing.hpp"
28#include "type_helpers.hpp"
29
30namespace dnnl {
31namespace impl {
32
33struct primitive_desc_t;
34
35struct primitive_desc_iterator_t : public c_compatible {
36 primitive_desc_iterator_t(engine_t *engine, const op_desc_t *op_desc,
37 const primitive_attr_t *attr, const primitive_desc_t *hint_fwd_pd,
38 int skip_idx = -1)
39 : idx_(-1)
40 , engine_(engine)
41 , op_desc_(nullptr)
42 , attr_(attr ? *attr : primitive_attr_t())
43 , hint_fwd_pd_(hint_fwd_pd)
44 , impl_list_(nullptr)
45 , last_idx_(0)
46 , skip_idx_(skip_idx)
47 , offset_(-1) {
48
49 op_desc_ = (op_desc_t *)std::malloc(sizeof(op_desc_t));
50 copy_c_op_desc(op_desc_, op_desc);
51
52 impl_list_ = engine_->get_implementation_list(op_desc_);
53
54 while (impl_list_[last_idx_])
55 ++last_idx_;
56 is_initialized_ = is_initialized_ && attr_.is_initialized();
57 }
58
59 ~primitive_desc_iterator_t() { std::free(op_desc_); }
60
61 engine_t *engine() const { return engine_; }
62
63 bool operator==(const primitive_desc_iterator_t &rhs) const {
64 return idx_ == rhs.idx_ && engine_ == rhs.engine_;
65 }
66 bool operator!=(const primitive_desc_iterator_t &rhs) const {
67 return !operator==(rhs);
68 }
69
70 primitive_desc_iterator_t end() const {
71 return primitive_desc_iterator_t(engine_, last_idx_);
72 }
73
74 primitive_desc_iterator_t &operator++() {
75 // Quick return to preserve state of the iterator that reached the end.
76 // The state is equal to the state of the iterator that end() returns.
77 if (idx_ == last_idx_) return *this;
78
79 offset_++;
80 pd_.reset();
81
82 std::vector<dnnl::impl::memory_desc_t> hint_mds;
83 if (hint_fwd_pd_) hint_mds = hint_fwd_pd_->hint_mds(true /* is_hint */);
84 primitive_hashing::key_t key(
85 engine_, op_desc_, &attr_, offset_, hint_mds);
86
87 pd_ = primitive_cache().get_pd(key);
88 if (pd_) { return *this; }
89
90 while (++idx_ != last_idx_) {
91 if (idx_ == skip_idx_) continue;
92 primitive_desc_t *candidate_pd = nullptr;
93 auto s = impl_list_[idx_](&candidate_pd, op_desc_, &attr_, engine_,
94 hint_fwd_pd_, offset_);
95 if (s == status::success) {
96 pd_.reset(candidate_pd);
97 break;
98 }
99 }
100 return *this;
101 }
102
103 const std::shared_ptr<primitive_desc_t> &operator*() const { return pd_; }
104
105 const primitive_attr_t &attr() const { return attr_; }
106
107 bool is_initialized() const { return is_initialized_; }
108
109protected:
110 int idx_;
111 engine_t *engine_;
112 std::shared_ptr<primitive_desc_t> pd_;
113 op_desc_t *op_desc_;
114 const primitive_attr_t attr_;
115 const primitive_desc_t *hint_fwd_pd_;
116 const impl_list_item_t *impl_list_;
117 int last_idx_;
118 int skip_idx_;
119 int offset_;
120
121private:
122 primitive_desc_iterator_t(engine_t *engine, int last_idx)
123 : idx_(last_idx)
124 , engine_(engine)
125 , op_desc_(nullptr)
126 , hint_fwd_pd_(nullptr)
127 , impl_list_(nullptr)
128 , last_idx_(last_idx)
129 , skip_idx_(-1)
130 , offset_(-1) {}
131
132 primitive_desc_iterator_t(primitive_desc_iterator_t &&other)
133 : idx_(other.idx_)
134 , engine_(other.engine_)
135 , pd_(std::move(other.pd_))
136 , op_desc_(other.op_desc_)
137 , attr_(other.attr_)
138 , hint_fwd_pd_(other.hint_fwd_pd_)
139 , impl_list_(other.impl_list_)
140 , skip_idx_(other.skip_idx_)
141 , offset_(other.offset_) {}
142
143 DNNL_DISALLOW_COPY_AND_ASSIGN(primitive_desc_iterator_t);
144};
145
146} // namespace impl
147} // namespace dnnl
148
149#endif
150