1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_PRIMITIVE_ITERATOR_HPP |
18 | #define COMMON_PRIMITIVE_ITERATOR_HPP |
19 | |
20 | #include "oneapi/dnnl/dnnl.h" |
21 | |
22 | #include "c_types_map.hpp" |
23 | #include "engine.hpp" |
24 | #include "impl_list_item.hpp" |
25 | #include "primitive_attr.hpp" |
26 | #include "primitive_cache.hpp" |
27 | #include "primitive_hashing.hpp" |
28 | #include "type_helpers.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | |
33 | struct primitive_desc_t; |
34 | |
35 | struct primitive_desc_iterator_t : public c_compatible { |
36 | primitive_desc_iterator_t(engine_t *engine, const op_desc_t *op_desc, |
37 | const primitive_attr_t *attr, const primitive_desc_t *hint_fwd_pd, |
38 | int skip_idx = -1) |
39 | : idx_(-1) |
40 | , engine_(engine) |
41 | , op_desc_(nullptr) |
42 | , attr_(attr ? *attr : primitive_attr_t()) |
43 | , hint_fwd_pd_(hint_fwd_pd) |
44 | , impl_list_(nullptr) |
45 | , last_idx_(0) |
46 | , skip_idx_(skip_idx) |
47 | , offset_(-1) { |
48 | |
49 | op_desc_ = (op_desc_t *)std::malloc(sizeof(op_desc_t)); |
50 | copy_c_op_desc(op_desc_, op_desc); |
51 | |
52 | impl_list_ = engine_->get_implementation_list(op_desc_); |
53 | |
54 | while (impl_list_[last_idx_]) |
55 | ++last_idx_; |
56 | is_initialized_ = is_initialized_ && attr_.is_initialized(); |
57 | } |
58 | |
59 | ~primitive_desc_iterator_t() { std::free(op_desc_); } |
60 | |
61 | engine_t *engine() const { return engine_; } |
62 | |
63 | bool operator==(const primitive_desc_iterator_t &rhs) const { |
64 | return idx_ == rhs.idx_ && engine_ == rhs.engine_; |
65 | } |
66 | bool operator!=(const primitive_desc_iterator_t &rhs) const { |
67 | return !operator==(rhs); |
68 | } |
69 | |
70 | primitive_desc_iterator_t end() const { |
71 | return primitive_desc_iterator_t(engine_, last_idx_); |
72 | } |
73 | |
74 | primitive_desc_iterator_t &operator++() { |
75 | // Quick return to preserve state of the iterator that reached the end. |
76 | // The state is equal to the state of the iterator that end() returns. |
77 | if (idx_ == last_idx_) return *this; |
78 | |
79 | offset_++; |
80 | pd_.reset(); |
81 | |
82 | std::vector<dnnl::impl::memory_desc_t> hint_mds; |
83 | if (hint_fwd_pd_) hint_mds = hint_fwd_pd_->hint_mds(true /* is_hint */); |
84 | primitive_hashing::key_t key( |
85 | engine_, op_desc_, &attr_, offset_, hint_mds); |
86 | |
87 | pd_ = primitive_cache().get_pd(key); |
88 | if (pd_) { return *this; } |
89 | |
90 | while (++idx_ != last_idx_) { |
91 | if (idx_ == skip_idx_) continue; |
92 | primitive_desc_t *candidate_pd = nullptr; |
93 | auto s = impl_list_[idx_](&candidate_pd, op_desc_, &attr_, engine_, |
94 | hint_fwd_pd_, offset_); |
95 | if (s == status::success) { |
96 | pd_.reset(candidate_pd); |
97 | break; |
98 | } |
99 | } |
100 | return *this; |
101 | } |
102 | |
103 | const std::shared_ptr<primitive_desc_t> &operator*() const { return pd_; } |
104 | |
105 | const primitive_attr_t &attr() const { return attr_; } |
106 | |
107 | bool is_initialized() const { return is_initialized_; } |
108 | |
109 | protected: |
110 | int idx_; |
111 | engine_t *engine_; |
112 | std::shared_ptr<primitive_desc_t> pd_; |
113 | op_desc_t *op_desc_; |
114 | const primitive_attr_t attr_; |
115 | const primitive_desc_t *hint_fwd_pd_; |
116 | const impl_list_item_t *impl_list_; |
117 | int last_idx_; |
118 | int skip_idx_; |
119 | int offset_; |
120 | |
121 | private: |
122 | primitive_desc_iterator_t(engine_t *engine, int last_idx) |
123 | : idx_(last_idx) |
124 | , engine_(engine) |
125 | , op_desc_(nullptr) |
126 | , hint_fwd_pd_(nullptr) |
127 | , impl_list_(nullptr) |
128 | , last_idx_(last_idx) |
129 | , skip_idx_(-1) |
130 | , offset_(-1) {} |
131 | |
132 | primitive_desc_iterator_t(primitive_desc_iterator_t &&other) |
133 | : idx_(other.idx_) |
134 | , engine_(other.engine_) |
135 | , pd_(std::move(other.pd_)) |
136 | , op_desc_(other.op_desc_) |
137 | , attr_(other.attr_) |
138 | , hint_fwd_pd_(other.hint_fwd_pd_) |
139 | , impl_list_(other.impl_list_) |
140 | , skip_idx_(other.skip_idx_) |
141 | , offset_(other.offset_) {} |
142 | |
143 | DNNL_DISALLOW_COPY_AND_ASSIGN(primitive_desc_iterator_t); |
144 | }; |
145 | |
146 | } // namespace impl |
147 | } // namespace dnnl |
148 | |
149 | #endif |
150 | |