1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5// ATTENTION: The code in this file is highly EXPERIMENTAL.
6// Adventurous users should note that the APIs will probably change.
7
8#include "onnx/common/assertions.h"
9
10namespace ONNX_NAMESPACE {
11
12// Intrusive doubly linked lists with sane reverse iterators.
13// The header file is named graph_node_list.h because it is ONLY
14// used for Graph's Node lists, and if you want to use it for other
15// things, you will have to do some refactoring.
16//
17// At the moment, the templated type T must support a few operations:
18//
19// - It must have a field: T* next_in_graph[2] = { nullptr, nullptr };
20// which are used for the intrusive linked list pointers.
21//
22// - It must have a method 'destroy()', which removes T from the
23// list and frees a T.
24//
25// In practice, we are only using it with Node and const Node. 'destroy()'
26// needs to be renegotiated if you want to use this somewhere else.
27//
28// Besides the benefits of being intrusive, unlike std::list, these lists handle
29// forward and backward iteration uniformly because we require a
30// "before-first-element" sentinel. This means that reverse iterators
31// physically point to the element they logically point to, rather than
32// the off-by-one behavior for all standard library reverse iterators.
33
// Indices into a node's next_in_graph[2] link array. Slot 0 holds the
// forward ("next") link and slot 1 holds the backward ("prev") link.
static constexpr size_t kNextDirection{0};
static constexpr size_t kPrevDirection{1};
36
// Forward declarations: the node type, plus the list and iterator templates
// that are defined further down in this header.
struct Node;

template <typename T>
struct generic_graph_node_list_iterator;

template <typename T>
struct generic_graph_node_list;

// Aliases over mutable nodes.
using graph_node_list = generic_graph_node_list<Node>;
using graph_node_list_iterator = generic_graph_node_list_iterator<Node>;

// Aliases over const nodes.
using const_graph_node_list = generic_graph_node_list<const Node>;
using const_graph_node_list_iterator = generic_graph_node_list_iterator<const Node>;
48
49template <typename T>
50struct generic_graph_node_list_iterator final {
51 generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {}
52 generic_graph_node_list_iterator(T* cur, size_t d) : cur(cur), d(d) {}
53 T* operator*() const {
54 return cur;
55 }
56 T* operator->() const {
57 return cur;
58 }
59 generic_graph_node_list_iterator& operator++() {
60 ONNX_ASSERT(cur);
61 cur = cur->next_in_graph[d];
62 return *this;
63 }
64 generic_graph_node_list_iterator operator++(int) {
65 generic_graph_node_list_iterator old = *this;
66 ++(*this);
67 return old;
68 }
69 generic_graph_node_list_iterator& operator--() {
70 ONNX_ASSERT(cur);
71 cur = cur->next_in_graph[reverseDir()];
72 return *this;
73 }
74 generic_graph_node_list_iterator operator--(int) {
75 generic_graph_node_list_iterator old = *this;
76 --(*this);
77 return old;
78 }
79
80 // erase cur without invalidating this iterator
81 // named differently from destroy so that ->/. bugs do not
82 // silently cause the wrong one to be called.
83 // iterator will point to the previous entry after call
84 void destroyCurrent() {
85 T* n = cur;
86 cur = cur->next_in_graph[reverseDir()];
87 n->destroy();
88 }
89 generic_graph_node_list_iterator reverse() {
90 return generic_graph_node_list_iterator(cur, reverseDir());
91 }
92
93 private:
94 size_t reverseDir() {
95 return d == kNextDirection ? kPrevDirection : kNextDirection;
96 }
97 T* cur;
98 size_t d; // direction 0 is forward 1 is reverse, see next_in_graph
99};
100
101template <typename T>
102struct generic_graph_node_list final {
103 using iterator = generic_graph_node_list_iterator<T>;
104 using const_iterator = generic_graph_node_list_iterator<const T>;
105 generic_graph_node_list_iterator<T> begin() {
106 return generic_graph_node_list_iterator<T>(head->next_in_graph[d], d);
107 }
108 generic_graph_node_list_iterator<const T> begin() const {
109 return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d);
110 }
111 generic_graph_node_list_iterator<T> end() {
112 return generic_graph_node_list_iterator<T>(head, d);
113 }
114 generic_graph_node_list_iterator<const T> end() const {
115 return generic_graph_node_list_iterator<const T>(head, d);
116 }
117 generic_graph_node_list_iterator<T> rbegin() {
118 return reverse().begin();
119 }
120 generic_graph_node_list_iterator<const T> rbegin() const {
121 return reverse().begin();
122 }
123 generic_graph_node_list_iterator<T> rend() {
124 return reverse().end();
125 }
126 generic_graph_node_list_iterator<const T> rend() const {
127 return reverse().end();
128 }
129 generic_graph_node_list reverse() {
130 return generic_graph_node_list(head, d == kNextDirection ? kPrevDirection : kNextDirection);
131 }
132 const generic_graph_node_list reverse() const {
133 return generic_graph_node_list(head, d == kNextDirection ? kPrevDirection : kNextDirection);
134 }
135 generic_graph_node_list(T* head, size_t d) : head(head), d(d) {}
136
137 private:
138 T* head;
139 size_t d;
140};
141
142template <typename T>
143static inline bool operator==(generic_graph_node_list_iterator<T> a, generic_graph_node_list_iterator<T> b) {
144 return *a == *b;
145}
146
147template <typename T>
148static inline bool operator!=(generic_graph_node_list_iterator<T> a, generic_graph_node_list_iterator<T> b) {
149 return *a != *b;
150}
151
152} // namespace ONNX_NAMESPACE
153
154namespace std {
155
156template <typename T>
157struct iterator_traits<ONNX_NAMESPACE::generic_graph_node_list_iterator<T>> {
158 using difference_type = int64_t;
159 using value_type = T*;
160 using pointer = T**;
161 using reference = T*&;
162 using iterator_category = bidirectional_iterator_tag;
163};
164
165} // namespace std
166