1 | #pragma once |
2 | |
#include <cstdint>
#include <iterator>

#include <c10/util/Exception.h>
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | |
8 | // Intrusive doubly linked lists with sane reverse iterators. |
9 | // The header file is named generic_graph_node_list.h because it is ONLY |
10 | // used for Graph's Node lists, and if you want to use it for other |
11 | // things, you will have to do some refactoring. |
12 | // |
13 | // At the moment, the templated type T must support a few operations: |
14 | // |
15 | // - It must have a field: T* next_in_graph[2] = { nullptr, nullptr }; |
16 | // which are used for the intrusive linked list pointers. |
17 | // |
18 | // - It must have a method 'destroy()', which removes T from the |
19 | // list and frees a T. |
20 | // |
21 | // In practice, we are only using it with Node and const Node. 'destroy()' |
22 | // needs to be renegotiated if you want to use this somewhere else. |
23 | // |
24 | // Regardless of the iteration direction, iterators always physically point |
25 | // to the element they logically point to, rather than |
26 | // the off-by-one behavior for all standard library reverse iterators like |
27 | // std::list. |
28 | |
// The list includes two sentinel nodes, one at the beginning and one at the
// end, with a circular link between them. It is an error to insert nodes after
// the end sentinel node but before the beginning sentinel node:
32 | |
33 | // Visualization showing only the next() links: |
34 | // HEAD -> first -> second -> ... -> last -> TAIL |
35 | // ^------------------------------------------ |
36 | |
37 | // Visualization showing only the prev() links: |
38 | // HEAD <- first <- second <- ... <- last <- TAIL |
39 | // ------------------------------------------^ |
40 | |
// Indices into T::next_in_graph: slot 0 holds the "next" link and slot 1 the
// "prev" link. The list and iterator types store one of these two values as
// their traversal direction.
static constexpr int kNextDirection{0};
static constexpr int kPrevDirection{1};
43 | |
// Element type of the aliases below; it is only declared here and defined
// elsewhere.
struct Node;

// Forward declarations of the iterator and of the list view so that the
// aliases can be spelled before either template is defined.
template <class T>
struct generic_graph_node_list_iterator;

template <class T>
struct generic_graph_node_list;

// Convenience names for the mutable and const Node instantiations.
using graph_node_list_iterator = generic_graph_node_list_iterator<Node>;
using const_graph_node_list_iterator =
    generic_graph_node_list_iterator<const Node>;
using graph_node_list = generic_graph_node_list<Node>;
using const_graph_node_list = generic_graph_node_list<const Node>;
56 | |
57 | template <typename T> |
58 | struct generic_graph_node_list_iterator { |
59 | generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {} |
60 | generic_graph_node_list_iterator(T* cur, int d) : cur(cur), d(d) {} |
61 | generic_graph_node_list_iterator( |
62 | const generic_graph_node_list_iterator& rhs) = default; |
63 | generic_graph_node_list_iterator(generic_graph_node_list_iterator&& rhs) = |
64 | default; |
65 | generic_graph_node_list_iterator& operator=( |
66 | const generic_graph_node_list_iterator& rhs) = default; |
67 | generic_graph_node_list_iterator& operator=( |
68 | generic_graph_node_list_iterator&& rhs) = default; |
69 | T* operator*() const { |
70 | return cur; |
71 | } |
72 | T* operator->() const { |
73 | return cur; |
74 | } |
75 | generic_graph_node_list_iterator& operator++() { |
76 | AT_ASSERT(cur); |
77 | cur = cur->next_in_graph[d]; |
78 | return *this; |
79 | } |
80 | generic_graph_node_list_iterator operator++(int) { |
81 | generic_graph_node_list_iterator old = *this; |
82 | ++(*this); |
83 | return old; |
84 | } |
85 | generic_graph_node_list_iterator& operator--() { |
86 | AT_ASSERT(cur); |
87 | cur = cur->next_in_graph[reverseDir()]; |
88 | return *this; |
89 | } |
90 | generic_graph_node_list_iterator operator--(int) { |
91 | generic_graph_node_list_iterator old = *this; |
92 | --(*this); |
93 | return old; |
94 | } |
95 | |
96 | // erase cur without invalidating this iterator |
97 | // named differently from destroy so that ->/. bugs do not |
98 | // silently cause the wrong one to be called. |
99 | // iterator will point to the previous entry after call |
100 | void destroyCurrent() { |
101 | T* n = cur; |
102 | cur = cur->next_in_graph[reverseDir()]; |
103 | n->destroy(); |
104 | } |
105 | generic_graph_node_list_iterator reverse() { |
106 | return generic_graph_node_list_iterator(cur, reverseDir()); |
107 | } |
108 | |
109 | private: |
110 | int reverseDir() { |
111 | return d == kNextDirection ? kPrevDirection : kNextDirection; |
112 | } |
113 | T* cur; |
114 | int d; // direction 0 is forward 1 is reverse, see next_in_graph |
115 | }; |
116 | |
117 | template <typename T> |
118 | struct generic_graph_node_list { |
119 | using iterator = generic_graph_node_list_iterator<T>; |
120 | using const_iterator = generic_graph_node_list_iterator<const T>; |
121 | generic_graph_node_list_iterator<T> begin() { |
122 | return generic_graph_node_list_iterator<T>(head->next_in_graph[d], d); |
123 | } |
124 | generic_graph_node_list_iterator<const T> begin() const { |
125 | return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d); |
126 | } |
127 | generic_graph_node_list_iterator<T> end() { |
128 | return generic_graph_node_list_iterator<T>(head->next_in_graph[!d], d); |
129 | } |
130 | generic_graph_node_list_iterator<const T> end() const { |
131 | return generic_graph_node_list_iterator<const T>( |
132 | head->next_in_graph[!d], d); |
133 | } |
134 | generic_graph_node_list_iterator<T> rbegin() { |
135 | return reverse().begin(); |
136 | } |
137 | generic_graph_node_list_iterator<const T> rbegin() const { |
138 | return reverse().begin(); |
139 | } |
140 | generic_graph_node_list_iterator<T> rend() { |
141 | return reverse().end(); |
142 | } |
143 | generic_graph_node_list_iterator<const T> rend() const { |
144 | return reverse().end(); |
145 | } |
146 | generic_graph_node_list reverse() { |
147 | return generic_graph_node_list(head->next_in_graph[!d], !d); |
148 | } |
149 | const generic_graph_node_list reverse() const { |
150 | return generic_graph_node_list(head->next_in_graph[!d], !d); |
151 | } |
152 | T* front() { |
153 | return head->next_in_graph[d]; |
154 | } |
155 | const T* front() const { |
156 | return head->next_in_graph[d]; |
157 | } |
158 | T* back() { |
159 | return head->next_in_graph[!d]; |
160 | } |
161 | const T* back() const { |
162 | return head->next_in_graph[!d]; |
163 | } |
164 | generic_graph_node_list(T* head, int d) : head(head), d(d) {} |
165 | |
166 | private: |
167 | T* head; // both head and tail are sentinel nodes |
168 | // the first real node is head->next_in_graph[d] |
169 | // the tail sentinel is head->next_in_graph[!d] |
170 | int d; |
171 | }; |
172 | |
173 | template <typename T> |
174 | static inline bool operator==( |
175 | generic_graph_node_list_iterator<T> a, |
176 | generic_graph_node_list_iterator<T> b) { |
177 | return *a == *b; |
178 | } |
179 | |
180 | template <typename T> |
181 | static inline bool operator!=( |
182 | generic_graph_node_list_iterator<T> a, |
183 | generic_graph_node_list_iterator<T> b) { |
184 | return *a != *b; |
185 | } |
186 | |
187 | } // namespace jit |
188 | } // namespace torch |
189 | |
190 | namespace std { |
191 | |
192 | template <typename T> |
193 | struct iterator_traits<torch::jit::generic_graph_node_list_iterator<T>> { |
194 | using difference_type = int64_t; |
195 | using value_type = T*; |
196 | using pointer = T**; |
197 | using reference = T*&; |
198 | using iterator_category = bidirectional_iterator_tag; |
199 | }; |
200 | |
201 | } // namespace std |
202 | |