1#pragma once
2
3#include <c10/util/Exception.h>
4
5namespace torch {
6namespace jit {
7
8// Intrusive doubly linked lists with sane reverse iterators.
9// The header file is named generic_graph_node_list.h because it is ONLY
10// used for Graph's Node lists, and if you want to use it for other
11// things, you will have to do some refactoring.
12//
13// At the moment, the templated type T must support a few operations:
14//
15// - It must have a field: T* next_in_graph[2] = { nullptr, nullptr };
16// which are used for the intrusive linked list pointers.
17//
// - It must have a method 'destroy()', which removes the T from the
// list and frees it.
20//
21// In practice, we are only using it with Node and const Node. 'destroy()'
22// needs to be renegotiated if you want to use this somewhere else.
23//
// Regardless of the iteration direction, an iterator always physically points
// to the element it logically refers to. This differs from standard library
// reverse iterators (e.g. std::list<T>::reverse_iterator), which physically
// point one element past the one they dereference to.
28
// The list includes two sentinel nodes, one at the beginning and one at the
// end, with a circular link between them. It is an error to insert nodes after
// the end sentinel node but before the beginning sentinel node.
32
33// Visualization showing only the next() links:
34// HEAD -> first -> second -> ... -> last -> TAIL
35// ^------------------------------------------
36
37// Visualization showing only the prev() links:
38// HEAD <- first <- second <- ... <- last <- TAIL
39// ------------------------------------------^
40
// Indices into T::next_in_graph: slot 0 is followed when an iterator walks
// forwards, slot 1 when it walks backwards.
constexpr int kNextDirection = 0;
constexpr int kPrevDirection = 1;

// Forward declarations, so the Node-specific aliases below can be named
// before the templates themselves are defined.
struct Node;

template <typename T>
struct generic_graph_node_list_iterator;

template <typename T>
struct generic_graph_node_list;

// Lists of mutable / const Nodes and the matching iterator types.
using graph_node_list = generic_graph_node_list<Node>;
using graph_node_list_iterator = generic_graph_node_list_iterator<Node>;
using const_graph_node_list = generic_graph_node_list<const Node>;
using const_graph_node_list_iterator =
    generic_graph_node_list_iterator<const Node>;
56
57template <typename T>
58struct generic_graph_node_list_iterator {
59 generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {}
60 generic_graph_node_list_iterator(T* cur, int d) : cur(cur), d(d) {}
61 generic_graph_node_list_iterator(
62 const generic_graph_node_list_iterator& rhs) = default;
63 generic_graph_node_list_iterator(generic_graph_node_list_iterator&& rhs) =
64 default;
65 generic_graph_node_list_iterator& operator=(
66 const generic_graph_node_list_iterator& rhs) = default;
67 generic_graph_node_list_iterator& operator=(
68 generic_graph_node_list_iterator&& rhs) = default;
69 T* operator*() const {
70 return cur;
71 }
72 T* operator->() const {
73 return cur;
74 }
75 generic_graph_node_list_iterator& operator++() {
76 AT_ASSERT(cur);
77 cur = cur->next_in_graph[d];
78 return *this;
79 }
80 generic_graph_node_list_iterator operator++(int) {
81 generic_graph_node_list_iterator old = *this;
82 ++(*this);
83 return old;
84 }
85 generic_graph_node_list_iterator& operator--() {
86 AT_ASSERT(cur);
87 cur = cur->next_in_graph[reverseDir()];
88 return *this;
89 }
90 generic_graph_node_list_iterator operator--(int) {
91 generic_graph_node_list_iterator old = *this;
92 --(*this);
93 return old;
94 }
95
96 // erase cur without invalidating this iterator
97 // named differently from destroy so that ->/. bugs do not
98 // silently cause the wrong one to be called.
99 // iterator will point to the previous entry after call
100 void destroyCurrent() {
101 T* n = cur;
102 cur = cur->next_in_graph[reverseDir()];
103 n->destroy();
104 }
105 generic_graph_node_list_iterator reverse() {
106 return generic_graph_node_list_iterator(cur, reverseDir());
107 }
108
109 private:
110 int reverseDir() {
111 return d == kNextDirection ? kPrevDirection : kNextDirection;
112 }
113 T* cur;
114 int d; // direction 0 is forward 1 is reverse, see next_in_graph
115};
116
117template <typename T>
118struct generic_graph_node_list {
119 using iterator = generic_graph_node_list_iterator<T>;
120 using const_iterator = generic_graph_node_list_iterator<const T>;
121 generic_graph_node_list_iterator<T> begin() {
122 return generic_graph_node_list_iterator<T>(head->next_in_graph[d], d);
123 }
124 generic_graph_node_list_iterator<const T> begin() const {
125 return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d);
126 }
127 generic_graph_node_list_iterator<T> end() {
128 return generic_graph_node_list_iterator<T>(head->next_in_graph[!d], d);
129 }
130 generic_graph_node_list_iterator<const T> end() const {
131 return generic_graph_node_list_iterator<const T>(
132 head->next_in_graph[!d], d);
133 }
134 generic_graph_node_list_iterator<T> rbegin() {
135 return reverse().begin();
136 }
137 generic_graph_node_list_iterator<const T> rbegin() const {
138 return reverse().begin();
139 }
140 generic_graph_node_list_iterator<T> rend() {
141 return reverse().end();
142 }
143 generic_graph_node_list_iterator<const T> rend() const {
144 return reverse().end();
145 }
146 generic_graph_node_list reverse() {
147 return generic_graph_node_list(head->next_in_graph[!d], !d);
148 }
149 const generic_graph_node_list reverse() const {
150 return generic_graph_node_list(head->next_in_graph[!d], !d);
151 }
152 T* front() {
153 return head->next_in_graph[d];
154 }
155 const T* front() const {
156 return head->next_in_graph[d];
157 }
158 T* back() {
159 return head->next_in_graph[!d];
160 }
161 const T* back() const {
162 return head->next_in_graph[!d];
163 }
164 generic_graph_node_list(T* head, int d) : head(head), d(d) {}
165
166 private:
167 T* head; // both head and tail are sentinel nodes
168 // the first real node is head->next_in_graph[d]
169 // the tail sentinel is head->next_in_graph[!d]
170 int d;
171};
172
173template <typename T>
174static inline bool operator==(
175 generic_graph_node_list_iterator<T> a,
176 generic_graph_node_list_iterator<T> b) {
177 return *a == *b;
178}
179
180template <typename T>
181static inline bool operator!=(
182 generic_graph_node_list_iterator<T> a,
183 generic_graph_node_list_iterator<T> b) {
184 return *a != *b;
185}
186
187} // namespace jit
188} // namespace torch
189
190namespace std {
191
192template <typename T>
193struct iterator_traits<torch::jit::generic_graph_node_list_iterator<T>> {
194 using difference_type = int64_t;
195 using value_type = T*;
196 using pointer = T**;
197 using reference = T*&;
198 using iterator_category = bidirectional_iterator_tag;
199};
200
201} // namespace std
202