1 | /******************************************************************************* |
2 | * Copyright 2018-2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_PRIMITIVE_ITERATOR_HPP |
18 | #define COMMON_PRIMITIVE_ITERATOR_HPP |
19 | |
20 | #include "oneapi/dnnl/dnnl.h" |
21 | |
22 | #include "c_types_map.hpp" |
23 | #include "engine.hpp" |
24 | #include "impl_list_item.hpp" |
25 | #include "primitive_attr.hpp" |
26 | #include "primitive_cache.hpp" |
27 | #include "primitive_desc.hpp" |
28 | #include "primitive_hashing.hpp" |
29 | #include "type_helpers.hpp" |
30 | |
31 | struct dnnl_primitive_desc_iterator : public dnnl::impl::c_compatible { |
32 | dnnl_primitive_desc_iterator(dnnl::impl::engine_t *engine, |
33 | const dnnl::impl::op_desc_t *op_desc, |
34 | const dnnl::impl::primitive_attr_t *attr, |
35 | const dnnl::impl::primitive_desc_t *hint_fwd_pd, int skip_idx = -1) |
36 | : idx_(-1) |
37 | , engine_(engine) |
38 | , op_desc_(op_desc) |
39 | , attr_(attr ? *attr : dnnl::impl::primitive_attr_t()) |
40 | , hint_fwd_pd_(hint_fwd_pd) |
41 | , impl_list_(engine_->get_implementation_list(op_desc_)) |
42 | , last_idx_(0) |
43 | , skip_idx_(skip_idx) |
44 | , offset_(-1) { |
45 | while (impl_list_[last_idx_]) |
46 | ++last_idx_; |
47 | is_initialized_ = is_initialized_ && attr_.is_initialized(); |
48 | } |
49 | |
50 | dnnl::impl::engine_t *engine() const { return engine_; } |
51 | |
52 | bool operator==(const dnnl::impl::primitive_desc_iterator_t &rhs) const { |
53 | return idx_ == rhs.idx_ && engine_ == rhs.engine_; |
54 | } |
55 | bool operator!=(const dnnl::impl::primitive_desc_iterator_t &rhs) const { |
56 | return !operator==(rhs); |
57 | } |
58 | |
59 | dnnl::impl::primitive_desc_iterator_t end() const { |
60 | return dnnl_primitive_desc_iterator(engine_, last_idx_); |
61 | } |
62 | |
63 | dnnl::impl::primitive_desc_iterator_t &operator++() { |
64 | // Quick return to preserve state of the iterator that reached the end. |
65 | // The state is equal to the state of the iterator that end() returns. |
66 | if (idx_ == last_idx_) return *this; |
67 | |
68 | offset_++; |
69 | pd_.reset(); |
70 | |
71 | std::vector<dnnl::impl::memory_desc_t> hint_mds; |
72 | if (hint_fwd_pd_) hint_mds = hint_fwd_pd_->hint_mds(true /* is_hint */); |
73 | dnnl::impl::primitive_hashing::key_t key( |
74 | engine_, op_desc_, &attr_, offset_, hint_mds); |
75 | |
76 | pd_ = dnnl::impl::primitive_cache().get_pd(key); |
77 | if (pd_) { return *this; } |
78 | |
79 | while (++idx_ != last_idx_) { |
80 | if (idx_ == skip_idx_) continue; |
81 | dnnl::impl::primitive_desc_t *candidate_pd = nullptr; |
82 | auto s = impl_list_[idx_](&candidate_pd, op_desc_, &attr_, engine_, |
83 | hint_fwd_pd_, offset_); |
84 | if (s == dnnl::impl::status::success) { |
85 | pd_.reset(candidate_pd); |
86 | break; |
87 | } |
88 | } |
89 | return *this; |
90 | } |
91 | |
92 | std::shared_ptr<dnnl::impl::primitive_desc_t> operator*() const { |
93 | if (*this == end() || pd_ == nullptr) return nullptr; |
94 | return pd_; |
95 | } |
96 | |
97 | const dnnl::impl::primitive_attr_t &attr() const { return attr_; } |
98 | |
99 | bool is_initialized() const { return is_initialized_; } |
100 | |
101 | protected: |
102 | int idx_; |
103 | dnnl::impl::engine_t *engine_; |
104 | std::shared_ptr<dnnl::impl::primitive_desc_t> pd_; |
105 | const dnnl::impl::op_desc_t *op_desc_; |
106 | const dnnl::impl::primitive_attr_t attr_; |
107 | const dnnl::impl::primitive_desc_t *hint_fwd_pd_; |
108 | const dnnl::impl::impl_list_item_t *impl_list_; |
109 | int last_idx_; |
110 | int skip_idx_; |
111 | int offset_; |
112 | |
113 | private: |
114 | dnnl_primitive_desc_iterator(dnnl::impl::engine_t *engine, int last_idx) |
115 | : idx_(last_idx) |
116 | , engine_(engine) |
117 | , op_desc_(nullptr) |
118 | , hint_fwd_pd_(nullptr) |
119 | , impl_list_(nullptr) |
120 | , last_idx_(last_idx) |
121 | , skip_idx_(-1) |
122 | , offset_(-1) {} |
123 | |
124 | dnnl_primitive_desc_iterator(dnnl_primitive_desc_iterator &&other) |
125 | : idx_(other.idx_) |
126 | , engine_(other.engine_) |
127 | , pd_(std::move(other.pd_)) |
128 | , op_desc_(other.op_desc_) |
129 | , attr_(other.attr_) |
130 | , hint_fwd_pd_(other.hint_fwd_pd_) |
131 | , impl_list_(other.impl_list_) |
132 | , skip_idx_(other.skip_idx_) |
133 | , offset_(other.offset_) {} |
134 | |
135 | DNNL_DISALLOW_COPY_AND_ASSIGN(dnnl_primitive_desc_iterator); |
136 | }; |
137 | |
138 | #endif |
139 | |