1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. |
6 | // Adventurous users should note that the APIs will probably change. |
7 | |
8 | #include "onnx/common/assertions.h" |
9 | |
10 | namespace ONNX_NAMESPACE { |
11 | |
12 | // Intrusive doubly linked lists with sane reverse iterators. |
13 | // The header file is named graph_node_list.h because it is ONLY |
14 | // used for Graph's Node lists, and if you want to use it for other |
15 | // things, you will have to do some refactoring. |
16 | // |
17 | // At the moment, the templated type T must support a few operations: |
18 | // |
19 | // - It must have a field: T* next_in_graph[2] = { nullptr, nullptr }; |
20 | // which are used for the intrusive linked list pointers. |
21 | // |
22 | // - It must have a method 'destroy()', which removes T from the |
23 | // list and frees a T. |
24 | // |
25 | // In practice, we are only using it with Node and const Node. 'destroy()' |
26 | // needs to be renegotiated if you want to use this somewhere else. |
27 | // |
28 | // Besides the benefits of being intrusive, unlike std::list, these lists handle |
29 | // forward and backward iteration uniformly because we require a |
30 | // "before-first-element" sentinel. This means that reverse iterators |
31 | // physically point to the element they logically point to, rather than |
32 | // the off-by-one behavior for all standard library reverse iterators. |
33 | |
// Indices into a node's next_in_graph[2] link array.
// Slot 0 is followed when walking forward, slot 1 when walking backward.
static constexpr size_t kNextDirection{0};
static constexpr size_t kPrevDirection{1};
36 | |
// Forward declarations: the element type and the two list templates.
struct Node;

template <typename T> struct generic_graph_node_list_iterator;
template <typename T> struct generic_graph_node_list;

// Mutable views over Node.
using graph_node_list = generic_graph_node_list<Node>;
using graph_node_list_iterator = generic_graph_node_list_iterator<Node>;

// Read-only views over Node.
using const_graph_node_list = generic_graph_node_list<const Node>;
using const_graph_node_list_iterator = generic_graph_node_list_iterator<const Node>;
48 | |
49 | template <typename T> |
50 | struct generic_graph_node_list_iterator final { |
51 | generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {} |
52 | generic_graph_node_list_iterator(T* cur, size_t d) : cur(cur), d(d) {} |
53 | T* operator*() const { |
54 | return cur; |
55 | } |
56 | T* operator->() const { |
57 | return cur; |
58 | } |
59 | generic_graph_node_list_iterator& operator++() { |
60 | ONNX_ASSERT(cur); |
61 | cur = cur->next_in_graph[d]; |
62 | return *this; |
63 | } |
64 | generic_graph_node_list_iterator operator++(int) { |
65 | generic_graph_node_list_iterator old = *this; |
66 | ++(*this); |
67 | return old; |
68 | } |
69 | generic_graph_node_list_iterator& operator--() { |
70 | ONNX_ASSERT(cur); |
71 | cur = cur->next_in_graph[reverseDir()]; |
72 | return *this; |
73 | } |
74 | generic_graph_node_list_iterator operator--(int) { |
75 | generic_graph_node_list_iterator old = *this; |
76 | --(*this); |
77 | return old; |
78 | } |
79 | |
80 | // erase cur without invalidating this iterator |
81 | // named differently from destroy so that ->/. bugs do not |
82 | // silently cause the wrong one to be called. |
83 | // iterator will point to the previous entry after call |
84 | void destroyCurrent() { |
85 | T* n = cur; |
86 | cur = cur->next_in_graph[reverseDir()]; |
87 | n->destroy(); |
88 | } |
89 | generic_graph_node_list_iterator reverse() { |
90 | return generic_graph_node_list_iterator(cur, reverseDir()); |
91 | } |
92 | |
93 | private: |
94 | size_t reverseDir() { |
95 | return d == kNextDirection ? kPrevDirection : kNextDirection; |
96 | } |
97 | T* cur; |
98 | size_t d; // direction 0 is forward 1 is reverse, see next_in_graph |
99 | }; |
100 | |
101 | template <typename T> |
102 | struct generic_graph_node_list final { |
103 | using iterator = generic_graph_node_list_iterator<T>; |
104 | using const_iterator = generic_graph_node_list_iterator<const T>; |
105 | generic_graph_node_list_iterator<T> begin() { |
106 | return generic_graph_node_list_iterator<T>(head->next_in_graph[d], d); |
107 | } |
108 | generic_graph_node_list_iterator<const T> begin() const { |
109 | return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d); |
110 | } |
111 | generic_graph_node_list_iterator<T> end() { |
112 | return generic_graph_node_list_iterator<T>(head, d); |
113 | } |
114 | generic_graph_node_list_iterator<const T> end() const { |
115 | return generic_graph_node_list_iterator<const T>(head, d); |
116 | } |
117 | generic_graph_node_list_iterator<T> rbegin() { |
118 | return reverse().begin(); |
119 | } |
120 | generic_graph_node_list_iterator<const T> rbegin() const { |
121 | return reverse().begin(); |
122 | } |
123 | generic_graph_node_list_iterator<T> rend() { |
124 | return reverse().end(); |
125 | } |
126 | generic_graph_node_list_iterator<const T> rend() const { |
127 | return reverse().end(); |
128 | } |
129 | generic_graph_node_list reverse() { |
130 | return generic_graph_node_list(head, d == kNextDirection ? kPrevDirection : kNextDirection); |
131 | } |
132 | const generic_graph_node_list reverse() const { |
133 | return generic_graph_node_list(head, d == kNextDirection ? kPrevDirection : kNextDirection); |
134 | } |
135 | generic_graph_node_list(T* head, size_t d) : head(head), d(d) {} |
136 | |
137 | private: |
138 | T* head; |
139 | size_t d; |
140 | }; |
141 | |
142 | template <typename T> |
143 | static inline bool operator==(generic_graph_node_list_iterator<T> a, generic_graph_node_list_iterator<T> b) { |
144 | return *a == *b; |
145 | } |
146 | |
147 | template <typename T> |
148 | static inline bool operator!=(generic_graph_node_list_iterator<T> a, generic_graph_node_list_iterator<T> b) { |
149 | return *a != *b; |
150 | } |
151 | |
152 | } // namespace ONNX_NAMESPACE |
153 | |
154 | namespace std { |
155 | |
156 | template <typename T> |
157 | struct iterator_traits<ONNX_NAMESPACE::generic_graph_node_list_iterator<T>> { |
158 | using difference_type = int64_t; |
159 | using value_type = T*; |
160 | using pointer = T**; |
161 | using reference = T*&; |
162 | using iterator_category = bidirectional_iterator_tag; |
163 | }; |
164 | |
165 | } // namespace std |
166 | |