1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | // ATTENTION: The code in this file is highly EXPERIMENTAL. |
6 | // Adventurous users should note that the APIs will probably change. |
7 | |
8 | #pragma once |
9 | |
10 | #include <stdint.h> |
11 | #include <algorithm> |
12 | #include <atomic> |
13 | #include <cstdint> |
14 | #include <functional> |
15 | #include <iostream> |
16 | #include <memory> |
17 | #include <sstream> |
18 | #include <string> |
19 | #include <unordered_set> |
20 | #include <vector> |
21 | |
22 | #include "onnx/common/array_ref.h" |
23 | #include "onnx/common/assertions.h" |
24 | #include "onnx/common/common.h" |
25 | #include "onnx/common/graph_node_list.h" |
26 | #include "onnx/common/interned_strings.h" |
27 | #include "onnx/common/tensor.h" |
28 | #include "onnx/string_utils.h" |
29 | |
30 | #define ONNX_DISALLOW_COPY_AND_ASSIGN(TypeName) \ |
31 | TypeName(const TypeName&) = delete; \ |
32 | TypeName& operator=(const TypeName&) = delete |
33 | |
34 | namespace ONNX_NAMESPACE { |
35 | |
36 | // Graph represents one "function" of computation. |
37 | // It uses a simple ownership model where the graph owns all the nodes inside it. |
38 | // All references inside the graph are raw pointers. |
39 | // Destroying the Graph will invalidate any pointers to nodes in the graph. |
40 | struct Graph; |
41 | |
42 | // Node is the base class of the IR graph. It represents one computation |
43 | // and dependencies on a list of Values. The "prim-ops", so to speak. |
44 | struct Node; |
45 | |
46 | // A Value represents an input or output to node that is either a |
47 | // Tensor or an opaque Handle object, as determined by type(). |
48 | struct Value; |
49 | |
50 | class ResourceGuard final { |
51 | std::function<void()> destructor_; |
52 | bool released_; |
53 | |
54 | public: |
55 | ONNX_DISALLOW_COPY_AND_ASSIGN(ResourceGuard); |
56 | explicit ResourceGuard(std::function<void()> destructor) : destructor_(std::move(destructor)), released_(false) {} |
57 | ResourceGuard(ResourceGuard&& other) = default; |
58 | ResourceGuard& operator=(ResourceGuard&& other) = default; |
59 | |
60 | ~ResourceGuard() { |
61 | if (!released_) |
62 | destructor_(); |
63 | } |
64 | |
65 | void release() { |
66 | released_ = true; |
67 | } |
68 | }; |
69 | |
// One entry of a tensor shape: either unknown, a symbolic parameter name,
// or a concrete integer extent.
struct Dimension final {
  // Unknown dimension: neither a concrete value nor a symbolic parameter.
  Dimension() = default;
  // Symbolic (named) dimension, e.g. "batch"; implicit by design.
  Dimension(std::string param) : is_unknown(false), param(std::move(param)) {} // NOLINT
  // Concrete integer dimension; implicit by design.
  Dimension(int64_t dim) : is_unknown(false), is_int(true), dim(dim) {} // NOLINT

  bool is_unknown = true;
  bool is_int = false;
  int64_t dim = -1;
  std::string param;
};
80 | |
// Discriminator for the payload stored in an AttributeValue. The short
// names mirror the ONNX AttributeProto field names; a trailing 's' denotes
// the list variant.
enum class AttributeKind : uint8_t {
  // float, float list, int, int list, string, string list,
  // tensor, tensor list, subgraph, subgraph list, type proto, type proto list
  f,
  fs,
  i,
  is,
  s,
  ss,
  t,
  ts,
  g,
  gs,
  tp,
  tps
};
97 | |
98 | static inline const char* toString(AttributeKind kind) { |
99 | static constexpr const char* names[] = {"f" , "fs" , "i" , "is" , "s" , "ss" , "t" , "ts" , "g" , "gs" , "tp" , "tps" }; |
100 | ONNX_ASSERT(size_t(kind) < sizeof(names) / sizeof(const char*)); |
101 | return names[int(kind)]; |
102 | } |
103 | |
// Abstract base for a single named attribute attached to a Node. Concrete
// payloads live in the Scalar/VectorAttributeValue subclasses below.
struct AttributeValue {
  explicit AttributeValue(Symbol name) : name(name) {}
  // Owning pointer type used by the Attributes container.
  using Ptr = std::unique_ptr<AttributeValue>;
  Symbol name;
  // Runtime discriminator for the stored payload type.
  virtual AttributeKind kind() const = 0;
  // Deep copy that preserves the concrete subclass.
  virtual Ptr clone() const = 0;
  virtual ~AttributeValue() = default;
};
112 | |
113 | template <typename T, AttributeKind Kind> |
114 | struct ScalarAttributeValue final : public AttributeValue { |
115 | using ConstructorType = const T&; |
116 | using ValueType = T; |
117 | ScalarAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(value_) {} |
118 | ValueType& value() { |
119 | return value_; |
120 | } |
121 | virtual Ptr clone() const override { |
122 | return Ptr(new ScalarAttributeValue(name, value_)); |
123 | } |
124 | virtual AttributeKind kind() const override { |
125 | return Kind; |
126 | } |
127 | |
128 | private: |
129 | ValueType value_; |
130 | }; |
131 | |
132 | template <typename T, AttributeKind Kind> |
133 | struct VectorAttributeValue final : public AttributeValue { |
134 | using ConstructorType = const std::vector<T>&&; |
135 | using ValueType = std::vector<T>; |
136 | VectorAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(std::move(value_)) {} |
137 | ValueType& value() { |
138 | return value_; |
139 | } |
140 | virtual AttributeKind kind() const override { |
141 | return Kind; |
142 | } |
143 | virtual std::unique_ptr<AttributeValue> clone() const override { |
144 | auto copy = value_; |
145 | return Ptr(new VectorAttributeValue(name, std::move(copy))); |
146 | } |
147 | |
148 | private: |
149 | ValueType value_; |
150 | }; |
151 | |
// Concrete attribute storage types, one per AttributeKind. Float attributes
// are stored as double and int attributes as int64_t, matching AttributeProto.
using FloatAttr = ScalarAttributeValue<double, AttributeKind::f>;
using FloatsAttr = VectorAttributeValue<double, AttributeKind::fs>;
using IntAttr = ScalarAttributeValue<int64_t, AttributeKind::i>;
using IntsAttr = VectorAttributeValue<int64_t, AttributeKind::is>;
using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>;
using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>;
using TensorAttr = ScalarAttributeValue<Tensor, AttributeKind::t>;
using TensorsAttr = VectorAttributeValue<Tensor, AttributeKind::ts>;
using GraphAttr = ScalarAttributeValue<std::shared_ptr<Graph>, AttributeKind::g>;
using GraphsAttr = VectorAttributeValue<std::shared_ptr<Graph>, AttributeKind::gs>;
using TypeProtoAttr = ScalarAttributeValue<TypeProto, AttributeKind::tp>;
using TypeProtosAttr = VectorAttributeValue<TypeProto, AttributeKind::tps>;
164 | |
// CRTP so that Node which inherits Attributes can be returned for
// method chaining e.g:
// Node * n = g->create(kSelect)->set_i(kOffset,3)->set_f(kValue,3.5);
// we return Derived* pointers because Nodes are normally held as pointers.
template <typename Derived>
struct Attributes {
  Attributes() {}
  // Replaces this attribute set with a deep copy of rhs's attributes.
  void copyAttributes(const Attributes& rhs) {
    values_.clear();
    values_.reserve(rhs.values_.size());
    for (auto& i : rhs.values_) {
      values_.push_back(i->clone());
    }
  }
  bool hasAttribute(Symbol name) const {
    return find(name, false) != values_.end();
  }
  // Kind of attribute 'name'; asserts that the attribute exists.
  AttributeKind kindOf(Symbol name) const {
    return (*find(name, true))->kind();
  }
  // Removes attribute 'name' (which must exist); returns this for chaining.
  Derived* removeAttribute(Symbol name) {
    values_.erase(find(name, true));
    return This();
  }
  bool hasAttributes() const {
    return !values_.empty();
  }
  // The names are returned in order, since name actually is the index.
  std::vector<Symbol> attributeNames() const {
    std::vector<Symbol> names;
    names.reserve(values_.size());
    for (auto& a : values_)
      names.push_back(a->name);
    return names;
  }

  // Generates a typed setter/getter pair per attribute kind, e.g. for
  // (Float, f):
  //   Derived* f_(Symbol name, double v);  // set; returns this for chaining
  //   const double& f(Symbol name) const;  // get; asserts presence
#define CREATE_ACCESSOR(Kind, method) \
  Derived* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
    return set<Kind##Attr>(name, std::forward<Kind##Attr::ConstructorType>(v)); \
  } \
  const Kind##Attr::ValueType& method(Symbol name) const { \
    return get<Kind##Attr>(name); \
  }
  CREATE_ACCESSOR(Float, f)
  CREATE_ACCESSOR(Floats, fs)
  CREATE_ACCESSOR(String, s)
  CREATE_ACCESSOR(Strings, ss)
  CREATE_ACCESSOR(Int, i)
  CREATE_ACCESSOR(Ints, is)
  CREATE_ACCESSOR(Tensor, t)
  CREATE_ACCESSOR(Tensors, ts)
  CREATE_ACCESSOR(Graph, g)
  CREATE_ACCESSOR(Graphs, gs)
  CREATE_ACCESSOR(TypeProto, tp)
  CREATE_ACCESSOR(TypeProtos, tps)

#undef CREATE_ACCESSOR

 private:
  // Downcast to the CRTP derived type so setters can chain.
  Derived* This() {
    return static_cast<Derived*>(this);
  }
  // Inserts or overwrites attribute 'name' with a new value of concrete
  // attribute type T.
  template <typename T>
  Derived* set(Symbol name, typename T::ConstructorType v) {
    auto it = find(name, false);
    auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
    if (it == values_.end()) {
      values_.push_back(std::move(nv));
    } else {
      *it = std::move(nv);
    }
    return This();
  }
  // Typed lookup; asserts the attribute exists. The static_cast assumes the
  // caller asks for the same kind the attribute was stored with.
  template <typename T>
  typename T::ValueType& get(Symbol name) const {
    auto it = find(name, true);
    T* child = static_cast<T*>(it->get());
    return child->value();
  }
  using AVPtr = AttributeValue::Ptr;
  // NB: For determinism, we use a vector rather than a hash map. This does
  // mean that lookups are O(n), so you shouldn't use Attributes to store
  // a big pile of messages.
  std::vector<AVPtr> values_;
  using iterator = std::vector<AVPtr>::iterator;
  // Linear search by name; asserts on absence when 'required' is true.
  iterator find(Symbol name, bool required) {
    auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; });
    ONNX_ASSERT(!required || it != values_.end());
    return it;
  }
  using const_iterator = std::vector<AVPtr>::const_iterator;
  const_iterator find(Symbol name, bool required) const {
    auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; });
    ONNX_ASSERTM(
        !required || it != values_.end(),
        "%s:%u: %s: required undefined attribute '%s'",
        __FILE__,
        __LINE__,
        __func__,
        name.toString());
    return it;
  }
};
268 | |
269 | // Each use is represented by this type, see Node::uses() |
270 | // 'user' is the consumer of the value, offset is the index into |
271 | // 'user's input this where the produces will be found. |
272 | struct Use final { |
273 | Use(Node* user, size_t offset) : user(user), offset(offset) {} |
274 | Node* user; |
275 | size_t offset; |
276 | }; |
277 | |
278 | static inline bool operator==(const Use& a, const Use& b) { |
279 | return a.user == b.user && a.offset == b.offset; |
280 | } |
281 | |
// the list types are intentionally simple, but we type-def
// them here so if we need to change them, refactoring will be easier
using node_list = std::vector<Node*>;
using value_list = std::vector<Value*>;
using use_list = std::vector<Use>;
// A node's operator identity is an interned Symbol (see interned_strings.h).
using NodeKind = Symbol;
288 | |
// A Value is one input/output "wire" of the graph: produced as output number
// 'offset_' of 'node_' and consumed by the nodes recorded in its use list.
// Type/shape metadata (elem_type_, sizes_) mirrors ValueInfoProto.
struct Value final {
  ONNX_DISALLOW_COPY_AND_ASSIGN(Value);
  Value(Node* node_, size_t offset_);
  Value(Value&&) = default;
  Value& operator=(Value&&) = default;
  ~Value() = default;

 private:
  friend struct Node;
  friend struct Graph;
  Node* node_;
  size_t offset_;
  size_t unique_ = 0; // unique id
  size_t stage_ = 0; // 0-forward, 1-backward, 2-double-backward,...
  use_list uses_in_current_graph_;
  bool has_unique_name_;
  std::string unique_name_;
  int32_t elem_type_;
  bool has_sizes_;
  std::vector<Dimension> sizes_;

 public:
  // Sets the element type and returns this for chaining.
  Value* setElemType(int32_t elem_type) {
    elem_type_ = elem_type;
    return this;
  }
  int32_t elemType() const {
    return elem_type_;
  }
  // True once setSizes() has been called (and wipeSizes() has not).
  bool has_sizes() const {
    return has_sizes_;
  }
  Value* setSizes(std::vector<Dimension> sizes) {
    has_sizes_ = true;
    sizes_ = std::move(sizes);
    return this;
  }
  // Discards all shape information.
  Value* wipeSizes() {
    has_sizes_ = false;
    sizes_ = std::vector<Dimension>();
    return this;
  }
  const std::vector<Dimension>& sizes() const {
    return sizes_;
  }
  size_t unique() const {
    return unique_;
  }
  bool has_unique_name() const {
    return has_unique_name_;
  }
  // The explicit unique name if one was set; otherwise the numeric unique
  // id rendered as a string.
  std::string uniqueName() const {
    if (has_unique_name())
      return unique_name_;
    return ONNX_NAMESPACE::to_string(unique());
  }
  Value* setUniqueName(const std::string& name, bool rename_subgraph_captured_nodes = true);
  Value* setStage(size_t s) {
    stage_ = s;
    return this;
  }
  size_t stage() const {
    return stage_;
  }
  // Producer of this value.
  Node* node() {
    return node_;
  }
  // Index of this value among its producer's outputs.
  size_t offset() const {
    return offset_;
  }
  const Node* node() const {
    return node_;
  }
  Graph* owningGraph();
  const Graph* owningGraph() const;
  // TODO: make this more const correct
  const use_list uses() const;

  // Replaces all uses of this node with 'newValue'.
  //
  // Given: %3 = f(%1, %2)
  // %4 = g(%3)
  // %5 = h(%3, %3)
  // Execute: %3.replaceAllUsesWith(%6)
  // Result: %3 = f(%1, %2)
  // %4 = g(%6)
  // %5 = h(%6, %6)
  void replaceAllUsesWith(Value* newValue);

  // Copies elem type, sizes and (when present) the unique name from 'from'.
  // NOTE(review): setSizes() is called unconditionally, so has_sizes()
  // becomes true here even when 'from' never had sizes set — confirm this
  // is intended before relying on has_sizes() after a copy.
  Value* copyMetadata(Value* from) {
    setElemType(from->elemType());
    setSizes(from->sizes());
    if (from->has_unique_name()) {
      setUniqueName(from->uniqueName());
    }
    return this;
  }
};
387 | |
struct Node : public Attributes<Node> {
  ONNX_DISALLOW_COPY_AND_ASSIGN(Node);
  friend struct Graph;
  friend struct Value;
  friend graph_node_list;
  friend const_graph_node_list;
  friend graph_node_list_iterator;
  friend const_graph_node_list_iterator;

 private:
  // Every node except the Return/Param sentinels occupies exactly one place
  // in the node list of graph_. The list is circular and doubly linked; the
  // Return node is the sentinel for both the beginning and the end, so the
  // list never contains null pointers. next_in_graph[0] is the next pointer
  // and next_in_graph[1] is the prev pointer; an array is used so the same
  // iterator class can serve both forward and reverse node lists.
  // The list order is always a valid topological sort
  // (see Note [Topological invariant] below).

  Node* next_in_graph[2] = {nullptr, nullptr};
  Node*& next() {
    return next_in_graph[kNextDirection];
  }
  Node*& prev() {
    return next_in_graph[kPrevDirection];
  }
  Node* const& next() const {
    return next_in_graph[kNextDirection];
  }
  Node* const& prev() const {
    return next_in_graph[kPrevDirection];
  }

  const NodeKind kind_;
  std::vector<Value*> inputs_;
  std::vector<Value*> outputs_;
  Graph* graph_;
  size_t stage_;
  // Optional NodeProto-style metadata. Each has_ flag distinguishes
  // "never set" from "set to the empty string".
  bool has_name_;
  std::string name_;
  bool has_domain_;
  std::string domain_;
  bool has_doc_string_;
  std::string doc_string_;

 protected:
  Node(Graph* graph_, NodeKind kind_); // defined after graph

 public:
  bool has_name() const {
    return has_name_;
  }
  const std::string& name() const {
    return name_;
  }
  void setName(std::string name) {
    has_name_ = true;
    name_ = std::move(name);
  }
  bool has_domain() const {
    return has_domain_;
  }
  const std::string& domain() const {
    return domain_;
  }
  void setDomain(std::string domain) {
    has_domain_ = true;
    domain_ = std::move(domain);
  }
  bool has_doc_string() const {
    return has_doc_string_;
  }
  const std::string& docString() const {
    return doc_string_;
  }
  void setDocString(std::string doc_string) {
    has_doc_string_ = true;
    doc_string_ = std::move(doc_string);
  }
  // The operator this node represents; fixed at construction.
  NodeKind kind() const {
    return kind_;
  }
  Graph* owningGraph() {
    return graph_;
  }
  const Graph* owningGraph() const {
    return graph_;
  }
  size_t stage() const {
    return stage_;
  }
  Node* setStage(size_t s) {
    stage_ = s;
    return this;
  }
  // NB: This returns an ArrayRef; that means that it will
  // get invalidated if you resize inputs (e.g., using addInput)
  // We can't return a std::vector<Node*>& because there's no
  // way to soundly cast to std::vector<const Node*> (an insane
  // implementation of std::vector could make this representationally
  // different.)
  ArrayRef<Value*> inputs() {
    return inputs_;
  }
  ArrayRef<const Value*> inputs() const {
    // Vectors are not convertible in const-ness of elements, but
    // raw pointers are.
    return {inputs_.data(), inputs_.size()};
  }
  // NB: This returns an ArrayRef; that means that it will
  // get invalidated if you resize outputs (e.g., using addOutput)
  // We can't return a std::vector<Node*>& because there's no
  // way to soundly cast to std::vector<const Node*> (an insane
  // implementation of std::vector could make this representationally
  // different.)
  ArrayRef<Value*> outputs() {
    return outputs_;
  }
  ArrayRef<const Value*> outputs() const {
    // Vectors are not convertible in const-ness of elements, but
    // raw pointers are.
    return {outputs_.data(), outputs_.size()};
  }
  // True iff any output of this node has at least one use.
  bool hasUses() const {
    for (auto o : outputs()) {
      if (!o->uses().empty())
        return true;
    }
    return false;
  }
  // Redirects every use of each of this node's outputs to the corresponding
  // output of 'n'. Both nodes must have the same output count.
  void replaceAllUsesWith(Node* n) {
    ONNX_ASSERT(outputs().size() == n->outputs().size());
    size_t nOutputs = outputs().size();
    for (size_t i = 0; i < nOutputs; i++) {
      outputs()[i]->replaceAllUsesWith(n->outputs()[i]);
    }
  }
  // lots of things like chunk have a single input or single output, so we have a
  // helper to make accessing it easier
  Value* input() {
    ONNX_ASSERT(inputs_.size() == 1);
    return inputs_.at(0);
  }
  Value* output() {
    ONNX_ASSERT(outputs_.size() == 1);
    return outputs_.at(0);
  }
  const Value* input() const {
    ONNX_ASSERT(inputs_.size() == 1);
    return inputs_.at(0);
  }
  // NOTE(review): unlike the const input() overload above, this const
  // overload returns a non-const Value* — presumably so callers can mutate
  // the output's metadata; confirm before tightening to const Value*.
  Value* output() const {
    ONNX_ASSERT(outputs_.size() == 1);
    return outputs_.at(0);
  }
  // Access a particular input. This is a checked index.
  Value* input(size_t i) {
    return inputs_.at(i);
  }
  const Value* input(size_t i) const {
    return inputs_.at(i);
  }

  // Graphs

  // Note [Topological invariant]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // We always maintain an up-to-date topological ordering of all nodes via
  // the next()/prev() links. All transformations to graphs must preserve
  // this topological ordering: for example, it is only valid to 'addInput'
  // with an input which is topologically before the current node.
  //
  // Usually, it is obvious whether or not topological order is maintained;
  // for example, if you are adding nodes to the end of the topsort, it's
  // impossible for them to refer to inputs that are not in the topsort.
  // If it is not obvious, please comment accordingly.

  // Add 'node' as an input to 'this' at the end of existing
  // arguments. Returns the added node for ease of chaining.
  //
  // Given: %3 = f(%1, %2)
  // Execute: %3.addInput(%4)
  // Result: %3 = f(%1, %2, %4)
  Value* addInput(Value* node) {
    ONNX_ASSERT(graph_ == node->owningGraph());
    node->uses_in_current_graph_.emplace_back(this, inputs_.size());
    inputs_.push_back(node);
    return node;
  }

  // Replace the input of 'this' at position 'i' with
  // 'newValue', returning the old node.
  //
  // Given: %3 = f(%1, %2)
  // Execute: %3.replaceInput(1, %4)
  // Result: %3 = f(%1, %4)
  Value* replaceInput(size_t i, Value* newValue) {
    ONNX_ASSERT(newValue->owningGraph() == graph_);
    Value* old = dropInput(i);
    inputs_[i] = newValue;
    newValue->uses_in_current_graph_.emplace_back(this, i);
    return old;
  }

  // Replace all occurrences of 'from' in the inputs of this
  // node with 'to'. Corresponds to llvm's replaceUsesOfWith.
  //
  // Given: %3 = f(%1, %2, %1)
  // Execute: %3.replaceInputWith(%1, %4)
  // Result: %3 = f(%4, %2, %4)
  void replaceInputWith(Value* from, Value* to) {
    ONNX_ASSERT(from->owningGraph() == graph_);
    ONNX_ASSERT(to->owningGraph() == graph_);
    size_t i = 0;
    for (auto input : inputs()) {
      if (input == from)
        replaceInput(i, to);
      i++;
    }
  }

  // Appends a fresh output Value produced by this node and returns it.
  Value* addOutput() {
    outputs_.push_back(new Value(this, outputs_.size()));
    return outputs_.back();
  }

  void eraseOutput(size_t i);

  // Insert unattached 'this' node before 'n' in the topological order.
  // Returns this (for chaining).
  //
  // Given: %3 = f(%1, %2)
  // %4 = g(%3)
  // and unattached: %5 = h(%1)
  // Execute: %5.insertBefore(%4)
  // Result: %3 = f(%1, %2)
  // %5 = h(%1)
  // %4 = g(%3)
  Node* insertBefore(Node* n) {
    ONNX_ASSERT(n->inGraphList());
    insertAfter(n->prev());
    return this;
  }

  // Insert unattached 'this' node after 'n' in the topological order.
  // Returns this (for chaining).
  //
  // Given: %3 = f(%1, %2)
  // %4 = g(%3)
  // and unattached: %5 = h(%1)
  // Execute: %5.insertAfter(%4)
  // Result: %3 = f(%1, %2)
  // %4 = g(%3)
  // %5 = h(%1)
  Node* insertAfter(Node* n) {
    ONNX_ASSERT(!inGraphList() && n->inGraphList());
    // Splice 'this' into the circular list between n and n->next().
    Node* next = n->next();
    n->next() = this;
    this->prev() = n;
    this->next() = next;
    next->prev() = this;
    return this;
  }

  // Move 'this' (already in the graph) after 'n' in the topological order.
  //
  // Given: %2 = f(%1)
  // %3 = g(%1)
  // Execute: %2.moveAfter(%3)
  // Result: %3 = g(%1)
  // %2 = f(%1)
  //
  void moveAfter(Node* n) {
    removeFromList();
    insertAfter(n);
  }

  // Move 'this' (already in the graph) before 'n' in the topological order.
  //
  // Given: %2 = f(%1)
  // %3 = g(%1)
  // Execute: %3.moveBefore(%2)
  // Result: %3 = g(%1)
  // %2 = f(%1)
  void moveBefore(Node* n) {
    removeFromList();
    insertBefore(n);
  }

  // Remove the input at 'i' from this node.
  //
  // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
  // removeInput.
  //
  // Given: %3 = f(%1, %2)
  // Execute: %3.removeInput(1)
  // Result: %3 = f(%1)
  void removeInput(size_t i) {
    dropInput(i);
    // everything after this input shifts left,
    // so we need to update their use offsets to match
    for (size_t j = i + 1; j < inputs_.size(); j++) {
      auto it = findUseForInput(j);
      it->offset--;
    }
    inputs_.erase(inputs_.begin() + i);
  }

  // Remove all inputs from a node.
  //
  // Given: %3 = f(%1, %2)
  // Execute: %3.removeAllInputs()
  // Result: %3 = f()
  void removeAllInputs() {
    for (size_t i = 0; i < inputs().size(); ++i)
      dropInput(i);
    inputs_.clear();
  }

  // Check whether this node is before node n in the graph.
  bool isBefore(Node* n);

  // iterators of the node list starting at this node
  // useful for resuming a search starting at this node
  graph_node_list_iterator iterator();
  graph_node_list_iterator reverseIterator();
  const_graph_node_list_iterator iterator() const;
  const_graph_node_list_iterator reverseIterator() const;

  // Remove 'this' from the instruction list and deallocate it.
  //
  // Invariant: no outputs of 'this' may have any uses.
  //
  // Given: %2 = f(%1)
  // %3 = g(%1)
  // Execute: %2.destroy()
  // Result: %3 = g(%1)
  void destroy();

  // Dynamically cast this node to the subclass indicated by the
  // template variable, returning nullptr if the cast is invalid..
  //
  // Example usage: if(auto s = n.cast<Select>()) { ... }
  //
  // TODO: Make this const correct
  template <typename T>
  T* cast() {
    if (T::Kind == kind())
      return static_cast<T*>(this);
    return nullptr;
  }
  // Like cast(), but asserts (with a descriptive message) instead of
  // returning nullptr on kind mismatch.
  template <typename T>
  T* expect() {
    ONNX_ASSERTM(T::Kind == kind(), "expected a %s but found a %s", T::Kind.toString(), kind().toString());
    return static_cast<T*>(this);
  }

  virtual ~Node() = default;

 private:
  // Lookup iterator in use list of _input i_ that corresponds to its use of _this_
  use_list::iterator findUseForInput(size_t i) {
    auto& input_uses = inputs_[i]->uses_in_current_graph_;
    // O(N) on the use list, but unless we get nodes with +100 uses
    // vector traversal still is probably faster than linked list
    auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i));
    ONNX_ASSERT(use_it != input_uses.end());
    return use_it;
  }

  // remove the use of input i, this sets input i to nullptr, but
  // is only used internally to Node before setting it to a new value
  // or erasing the entry from the list.
  Value* dropInput(size_t i) {
    ONNX_ASSERT(i < inputs_.size());
    auto input_node = inputs_[i];
    auto use_it = findUseForInput(i);
    input_node->uses_in_current_graph_.erase(use_it);
    inputs_[i] = nullptr;
    return input_node;
  }

  // A node is in a graph's list iff its next pointer is set; next and prev
  // are always set or cleared together (asserted here).
  bool inGraphList() const {
    ONNX_ASSERT(next() != nullptr || prev() == nullptr);
    return next() != nullptr;
  }
  // Unlink 'this' from the circular list, clearing both link pointers.
  void removeFromList() {
    ONNX_ASSERT(inGraphList());
    Node* next = this->next();
    Node* prev = this->prev();
    prev->next() = next;
    next->prev() = prev;
    this->next() = nullptr;
    this->prev() = nullptr;
  }

 protected:
  // subclasses must override
  // this function is used by createClone to initialize a new version
  // of a node in another graph. It should allocate a new instance of the same
  // concrete type as 'this', but in graph 'g' which might be different
  // than graph_
  virtual Node* allocNewInstance(Graph* g) {
    return new Node(g, kind());
  }
  // create a copy of all properties of Node s into this.
  // subclasses should extend if they have additional information to copy.
  // 'this' will be allocated with s->allocNewInstance(g) so it should have
  // the same concrete type as 's'
  //
  // NB: This does NOT clone stages. You're expected to set the stage correctly
  // if you are going to preserve it.
  virtual void cloneFrom(Node* s) {
    copyAttributes(*s);
  }
};
803 | |
// A class with the same properties as OperatorSetIdProto, but without protobuf
// overhead, resulting in a simpler and more readable workflow.
class OpSetID final {
 private:
  std::string domain_;
  int64_t version_;

 public:
  explicit OpSetID(const OperatorSetIdProto& proto) : domain_(proto.domain()), version_(proto.version()) {}

  // Default Domain Constructor
  explicit OpSetID(const int64_t version) : domain_(""), version_(version) {}

  explicit OpSetID(const std::string& domain, int64_t version) : domain_(domain), version_(version) {}

  // Serializes as "<domain>$<version>" — the format fromString() consumes.
  // (Note: the separator is '$', not '&'.)
  std::string toString() const {
    return domain_ + "$" + ONNX_NAMESPACE::to_string(version_);
  }

  // Parses a string of the form "<domain>$<version>", the inverse of
  // toString(). Asserts (or throws, depending on build) on malformed input.
  static OpSetID fromString(const std::string& target) {
    ONNX_TRY {
      std::string new_domain = target.substr(0, target.find("$"));
      int new_version = ONNX_NAMESPACE::stoi(target.substr(target.find("$") + 1, target.length()).c_str());
      return OpSetID(new_domain, new_version);
    }
    ONNX_CATCH(const std::runtime_error& e) {
      ONNX_HANDLE_EXCEPTION([&]() { ONNX_ASSERTM(false, "Error in fromString: %s", e.what()); });
    }

    // The control will never reach here.
    // In the default build where exceptions are turned on in case of any error
    // the control will enter catch block where an exception will be thrown again.
    // In case of "no exception build" the code aborts at the site of first exception.
    // Adding this to appease the warning "control may reach end of non-void function"
    // as the mac build fails when ONNX_WERROR==ON
    return OpSetID("", 0);
  }

  const std::string& domain() const {
    return domain_;
  }

  int64_t version() const {
    return version_;
  }

  // Adds 'step' (which may be negative) to the stored version.
  void incrementVersion(int64_t step) {
    version_ += step;
  }

  void setVersion(int64_t newVal) {
    version_ = newVal;
  }
};
860 | |
861 | struct Graph final { |
862 | ONNX_DISALLOW_COPY_AND_ASSIGN(Graph); |
863 | friend struct Node; |
864 | friend struct Value; |
865 | |
866 | private: |
867 | // only used to keep track of allocated nodes |
868 | // actual representation of Graph is done with |
869 | // inputs, outputs, nodes |
870 | |
871 | std::unordered_set<const Node*> all_nodes; |
872 | std::unordered_set<const Value*> all_values; |
873 | size_t next_unique_; |
874 | |
875 | size_t new_node_stage_; |
876 | |
877 | // holds outputs in a way that can be reflected |
878 | // as a Use object |
879 | // also used as the beginning/end of the circular node list to avoid |
880 | // having corner cases where the list is empty. |
881 | Node* const output_; |
882 | Node* const input_; |
883 | // Create an independent node list for those initializers do not exist in input |
884 | Node* const initializer_node_; |
885 | |
886 | std::vector<Tensor> initializers_; |
887 | std::vector<std::string> initializer_names_; |
888 | |
889 | bool has_name_; |
890 | std::string name_; |
891 | bool has_doc_string_; |
892 | std::string doc_string_; |
893 | |
894 | std::vector<OpSetID> opset_versions_; |
895 | |
  // Returns true iff 'name' is not already in use anywhere reachable from
  // this graph: initializer names, every node's input/output unique names,
  // and — recursively — any subgraphs held as g/gs attributes.
  bool isNameUnique(const std::string& name) const {
    if (std::find(initializer_names_.cbegin(), initializer_names_.cend(), name) != initializer_names_.cend()) {
      return false;
    }
    const auto f = [&name](const Value* v) { return v->uniqueName() == name; };
    for (const Node* node : all_nodes) {
      for (const auto& attr : node->attributeNames()) {
        if (node->kindOf(attr) == AttributeKind::g) {
          const auto& subgraph = node->g(attr);
          if (!subgraph->isNameUnique(name)) {
            return false;
          }
        } else if (node->kindOf(attr) == AttributeKind::gs) {
          for (const auto& subgraph : node->gs(attr)) {
            if (!subgraph->isNameUnique(name)) {
              return false;
            }
          }
        }
      }
      const auto found_in = std::find_if(node->inputs().begin(), node->inputs().end(), f);
      if (found_in != node->inputs().end()) {
        return false;
      }
      const auto found_out = std::find_if(node->outputs().begin(), node->outputs().end(), f);
      if (found_out != node->outputs().end()) {
        return false;
      }
    }
    return true;
  }
927 | |
928 | public: |
929 | Graph() |
930 | : next_unique_(0), |
931 | new_node_stage_(0), |
932 | output_(initOutput(create(kReturn, 0))), |
933 | input_(create(kParam, 0)), |
934 | initializer_node_(create(kParam, 0)), |
935 | has_name_(false), |
936 | has_doc_string_(false) {} |
937 | |
938 | bool has_doc_string() const { |
939 | return has_doc_string_; |
940 | } |
941 | const std::string& docString() { |
942 | return doc_string_; |
943 | } |
944 | void setDocString(std::string doc_string) { |
945 | has_doc_string_ = true; |
946 | doc_string_ = std::move(doc_string); |
947 | } |
948 | |
949 | void addInitializer(Tensor& initializer) { |
950 | if (initializer.name().empty()) { |
951 | initializer.setName(ONNX_NAMESPACE::to_string(getNextUnique())); |
952 | } |
953 | initializers_.push_back(initializer); |
954 | initializer_names_.push_back(initializer.name()); |
955 | } |
956 | |
957 | // For IR >= 4, initializer is not required to exist in input |
958 | // Add initializer into initializer node list and return its Value |
959 | Value* addInitializerAndCreateValue(Tensor& initializer) { |
960 | addInitializer(initializer); |
961 | auto* init_value = initializer_node_->addOutput(); |
962 | std::vector<Dimension> dim_sizes{initializer.sizes().cbegin(), initializer.sizes().cend()}; |
963 | init_value->setUniqueName(initializer.name()); |
964 | init_value->setSizes(dim_sizes); |
965 | init_value->setElemType(initializer.elem_type()); |
966 | return init_value; |
967 | } |
968 | |
969 | void eraseInitializer(const std::string& name) { |
970 | initializers_.erase( |
971 | std::remove_if( |
972 | initializers_.begin(), |
973 | initializers_.end(), |
974 | [&name](Tensor& initializer) { return initializer.name() == name; }), |
975 | initializers_.end()); |
976 | initializer_names_.erase( |
977 | std::remove(initializer_names_.begin(), initializer_names_.end(), name), initializer_names_.end()); |
978 | for (size_t i = 0; i < initializer_node_->outputs().size(); i++) { |
979 | if (initializer_node_->outputs()[i]->uniqueName() == name) { |
980 | initializer_node_->eraseOutput(i); |
981 | break; |
982 | } |
983 | } |
984 | } |
985 | void clearInitializers() { |
986 | initializers_.clear(); |
987 | initializer_names_.clear(); |
988 | } |
989 | const std::vector<Tensor>& initializers() const { |
990 | return initializers_; |
991 | } |
992 | const std::vector<std::string>& initializer_names() const { |
993 | return initializer_names_; |
994 | } |
995 | std::vector<Tensor>::const_iterator getInitializer(const std::string& name) const { |
996 | for (auto it = initializers_.cbegin(); it != initializers_.cend(); ++it) { |
997 | if (name == it->name()) { |
998 | return it; |
999 | } |
1000 | } |
1001 | return initializers_.end(); |
1002 | } |
1003 | bool is_constant_initializer(const Value* value) const { |
1004 | return value->node() == initializer_node_; |
1005 | } |
1006 | ArrayRef<Value*> inputs() { |
1007 | return input_->outputs(); |
1008 | } |
1009 | ArrayRef<const Value*> inputs() const { |
1010 | const auto& inputs = input_->outputs(); |
1011 | return {inputs.data(), inputs.size()}; |
1012 | } |
1013 | ArrayRef<Value*> outputs() { |
1014 | return output_->inputs(); |
1015 | } |
1016 | ArrayRef<const Value*> outputs() const { |
1017 | return static_cast<const Node*>(output_)->inputs(); |
1018 | } |
1019 | graph_node_list nodes() { |
1020 | return graph_node_list(output_, kNextDirection); |
1021 | } |
1022 | const_graph_node_list nodes() const { |
1023 | return const_graph_node_list(output_, kNextDirection); |
1024 | } |
1025 | |
1026 | std::vector<OpSetID>& opset_versions_mutable() { |
1027 | return opset_versions_; |
1028 | } |
1029 | |
1030 | size_t getNextUnique() { |
1031 | std::string next_unique_name = ONNX_NAMESPACE::to_string(++next_unique_); |
1032 | while (!isNameUnique(next_unique_name)) { |
1033 | next_unique_name = ONNX_NAMESPACE::to_string(++next_unique_); |
1034 | } |
1035 | return next_unique_; |
1036 | } |
1037 | |
1038 | // These invocations of begin() on output of function are OK |
1039 | // because graph_node_list is non-owning, so it doesn't matter |
1040 | // if it immediately dies after the invocation. |
1041 | graph_node_list_iterator begin() { |
1042 | return nodes().begin(); |
1043 | } |
1044 | const_graph_node_list_iterator begin() const { |
1045 | return nodes().begin(); |
1046 | } |
1047 | graph_node_list_iterator end() { |
1048 | return nodes().end(); |
1049 | } |
1050 | const_graph_node_list_iterator end() const { |
1051 | return nodes().end(); |
1052 | } |
1053 | graph_node_list_iterator rbegin() { |
1054 | return nodes().rbegin(); |
1055 | } |
1056 | const_graph_node_list_iterator rbegin() const { |
1057 | return nodes().rbegin(); |
1058 | } |
1059 | graph_node_list_iterator rend() { |
1060 | return nodes().rend(); |
1061 | } |
1062 | const_graph_node_list_iterator rend() const { |
1063 | return nodes().rend(); |
1064 | } |
1065 | Node* return_node() { |
1066 | return output_; |
1067 | } |
1068 | const Node* return_node() const { |
1069 | return output_; |
1070 | } |
1071 | |
1072 | Value* addInput() { |
1073 | return input_->addOutput(); |
1074 | } |
1075 | void eraseInput(size_t i) { |
1076 | input_->eraseOutput(i); |
1077 | } |
1078 | void advanceStage() { |
1079 | new_node_stage_++; |
1080 | } |
1081 | void setStage(size_t new_stage) { |
1082 | new_node_stage_ = new_stage; |
1083 | } |
1084 | size_t stage() const { |
1085 | return new_node_stage_; |
1086 | } |
1087 | ResourceGuard setStageTemporary(size_t s) { |
1088 | auto prev_stage = new_node_stage_; |
1089 | new_node_stage_ = s; |
1090 | return ResourceGuard([prev_stage, this]() { this->new_node_stage_ = prev_stage; }); |
1091 | } |
1092 | |
1093 | size_t registerOutput(Value* n) { |
1094 | output_->addInput(n); |
1095 | return outputs().size() - 1; |
1096 | } |
1097 | |
1098 | Node* create(NodeKind kind, size_t num_outputs = 1) { |
1099 | // NB: Node constructor adds node to all_nodes |
1100 | auto n = new Node(this, kind); |
1101 | for (size_t i = 0; i < num_outputs; i++) |
1102 | n->addOutput(); |
1103 | return n; |
1104 | } |
1105 | |
1106 | Node* create(NodeKind kind, ArrayRef<Value*> inputs, size_t num_outputs = 1) { |
1107 | auto n = create(kind, num_outputs); |
1108 | for (auto i : inputs) |
1109 | n->addInput(i); |
1110 | return n; |
1111 | } |
1112 | |
1113 | Node* appendNode(Node* n) { |
1114 | ONNX_ASSERT(n->graph_ == this && !n->inGraphList()); |
1115 | n->insertBefore(output_); |
1116 | return n; |
1117 | } |
1118 | |
1119 | Node* prependNode(Node* n) { |
1120 | ONNX_ASSERT(n->graph_ == this && !n->inGraphList()); |
1121 | n->insertAfter(output_); |
1122 | return n; |
1123 | } |
1124 | |
1125 | // Adds to graph initializer list, initializer names list, and as a graph input |
1126 | // Also syncs the initializer name, tensor name, and value name |
1127 | // Create an initializer whose value is stored in input |
1128 | Value* addInitializerAndInput(const Tensor& initializer, const std::string& name) { |
1129 | Tensor initializerCopy = initializer; |
1130 | std::vector<Dimension> dim_sizes{initializerCopy.sizes().cbegin(), initializerCopy.sizes().cend()}; |
1131 | Value* new_init = addInput(); |
1132 | initializerCopy.setName(name); |
1133 | new_init->setUniqueName(name); |
1134 | new_init->setSizes(dim_sizes); |
1135 | new_init->setElemType(initializerCopy.elem_type()); |
1136 | addInitializer(initializerCopy); |
1137 | return new_init; |
1138 | } |
1139 | |
1140 | Value* addInitializerAndInput(const Tensor& initializer) { |
1141 | return addInitializerAndInput(initializer, ONNX_NAMESPACE::to_string(getNextUnique())); |
1142 | } |
1143 | |
1144 | // Erases from graph initializer list, initializer names list, and as a graph input |
1145 | // Must have no uses |
1146 | void eraseInitializerAndInput(Value* v) { |
1147 | eraseInitializer(v->uniqueName()); |
1148 | if (v->node() == input_) { |
1149 | eraseInput(v->offset()); |
1150 | } |
1151 | } |
1152 | |
1153 | ~Graph() { |
1154 | for (const Node* n : all_nodes) |
1155 | delete n; |
1156 | for (const Value* v : all_values) |
1157 | delete v; |
1158 | } |
1159 | |
1160 | std::string toString() const { |
1161 | std::ostringstream oss; |
1162 | oss << *this; |
1163 | return oss.str(); |
1164 | } |
1165 | |
1166 | bool has_name() const { |
1167 | return has_name_; |
1168 | } |
1169 | |
1170 | const std::string& name() const { |
1171 | return name_; |
1172 | } |
1173 | |
1174 | void setName(std::string name) { |
1175 | has_name_ = true; |
1176 | name_ = std::move(name); |
1177 | } |
1178 | |
1179 | friend std::ostream& operator<<(std::ostream& out, const Graph& g); |
1180 | |
1181 | void forSelfAndEachSubGraph(const std::function<void(Graph*)>& fn) { |
1182 | fn(this); |
1183 | |
1184 | for (const Node* node : all_nodes) { |
1185 | for (const auto& attr : node->attributeNames()) { |
1186 | if (node->kindOf(attr) == AttributeKind::g) { |
1187 | std::shared_ptr<Graph> subgraph = node->g(attr); |
1188 | subgraph->forSelfAndEachSubGraph(fn); |
1189 | } else if (node->kindOf(attr) == AttributeKind::gs) { |
1190 | for (const auto& subgraph : node->gs(attr)) { |
1191 | subgraph->forSelfAndEachSubGraph(fn); |
1192 | } |
1193 | } |
1194 | } |
1195 | } |
1196 | } |
1197 | |
1198 | void forSelfAndEachSubGraph(const std::function<void(const Graph*)>& fn) const { |
1199 | std::function<void(Graph*)> tmp_fn = [fn](Graph* graph) { fn(graph); }; |
1200 | const_cast<Graph*>(this)->forSelfAndEachSubGraph(tmp_fn); |
1201 | } |
1202 | |
1203 | void forEachNode(const std::function<void(Node*)>& fn) { |
1204 | forSelfAndEachSubGraph([fn](Graph* graph) { |
1205 | for (Node* node : graph->nodes()) { |
1206 | fn(node); |
1207 | } |
1208 | }); |
1209 | } |
1210 | |
1211 | void forEachNode(const std::function<void(const Node*)>& fn) const { |
1212 | std::function<void(Node*)> tmp_fn = [fn](Node* node) { fn(node); }; |
1213 | const_cast<Graph*>(this)->forEachNode(tmp_fn); |
1214 | } |
1215 | |
1216 | private: |
1217 | // should only be called in the constructor |
1218 | Node* initOutput(Node* p) { |
1219 | p->next() = p; |
1220 | p->prev() = p; |
1221 | p->setStage(std::numeric_limits<size_t>::max()); |
1222 | return p; |
1223 | } |
1224 | |
1225 | void freeNode(Node* n) { |
1226 | auto it = all_nodes.find(n); |
1227 | ONNX_ASSERT(it != all_nodes.end()); |
1228 | delete *it; |
1229 | all_nodes.erase(it); |
1230 | } |
1231 | void freeValue(Value* v) { |
1232 | auto it = all_values.find(v); |
1233 | ONNX_ASSERT(it != all_values.end()); |
1234 | delete *it; |
1235 | all_values.erase(it); |
1236 | } |
1237 | }; |
1238 | |
// Constructs a value produced by `node_` at output slot `offset_`.
// Assigns a fresh numeric unique id and the graph's current stage, and
// registers the value with the owning graph (which takes ownership).
inline Value::Value(Node* node_, size_t offset_)
    : node_(node_),
      offset_(offset_),
      unique_(node_->graph_->getNextUnique()),
      stage_(node_->graph_->new_node_stage_),
      has_unique_name_(false),
      elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED),
      has_sizes_(false) {
  node_->graph_->all_values.emplace(this);
}
1249 | |
1250 | inline Graph* Value::owningGraph() { |
1251 | return node()->owningGraph(); |
1252 | } |
1253 | |
1254 | inline const Graph* Value::owningGraph() const { |
1255 | return node()->owningGraph(); |
1256 | } |
1257 | |
// A `captured` node in a subgraph identifies the value it captures by
// storing that value's unique name, so stale unique names held by
// `captured` nodes must be updated on a rename as well.
// Initializer names are also stored in graph.initializer_names_, which
// must be kept in sync too.
1263 | inline Value* Value::setUniqueName(const std::string& name, bool update_related_names) { |
1264 | if (has_unique_name() && update_related_names) { |
1265 | auto* graph = owningGraph(); |
1266 | auto old_name = unique_name_; |
1267 | for (size_t i = 0; i < owningGraph()->initializer_names_.size(); i++) { |
1268 | auto& initializer_name = owningGraph()->initializer_names_[i]; |
1269 | if (initializer_name == old_name) { |
1270 | initializer_name = name; |
1271 | owningGraph()->initializers_[i].setName(name); |
1272 | } |
1273 | } |
1274 | graph->forEachNode([this, &name, &old_name](Node* node) { |
1275 | if (node->owningGraph() == this->owningGraph()) { |
1276 | // skip non-subgraph |
1277 | return; |
1278 | } |
1279 | if (node->kind() == kCaptured) { |
1280 | Value* output = node->output(); |
1281 | if (output->uniqueName() == old_name) { |
1282 | output->setUniqueName(name, false); |
1283 | } |
1284 | } |
1285 | }); |
1286 | } |
1287 | unique_name_ = name; |
1288 | has_unique_name_ = true; |
1289 | return this; |
1290 | } |
1291 | |
// Redirects every use of this value (including uses reached through
// kCaptured nodes in subgraphs) to `newValue`.  Both values must belong to
// the same graph.  Afterwards this value has no remaining uses.
inline void Value::replaceAllUsesWith(Value* newValue) {
  auto* graph = owningGraph();
  ONNX_ASSERT(graph == newValue->owningGraph());
  // propagate sizes and elem type
  if (this->has_sizes()) {
    newValue->setSizes(this->sizes());
  }
  if (this->elemType() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
    newValue->setElemType(this->elemType());
  }
  const auto unique_name = this->uniqueName();
  // We do not want the optimization to change the graph output name
  if (std::find(graph->outputs().rbegin(), graph->outputs().rend(), this) != graph->outputs().rend()) {
    newValue->setUniqueName(unique_name);
    // The "unique" semantic of unique_name should be kept or uses()
    // will return an incorrect result when the value is used in subgraph
    this->setUniqueName(ONNX_NAMESPACE::to_string(graph->getNextUnique()), false);
  }
  // Transfer the direct (current-graph) uses to newValue, rewriting each
  // user's input slot in place.
  newValue->uses_in_current_graph_.reserve(this->uses_in_current_graph_.size());
  for (auto u : uses_in_current_graph_) {
    u.user->inputs_[u.offset] = newValue;
    newValue->uses_in_current_graph_.push_back(u);
  }
  // Re-point kCaptured nodes in subgraphs (matched by unique name) at the
  // replacement value's name.
  graph->forEachNode([this, &newValue, &unique_name](Node* node) {
    if (node->owningGraph() == this->owningGraph()) {
      // skip non-subgraph
      return;
    }
    if (node->kind() == kCaptured) {
      Value* output = node->output();
      if (output->uniqueName() == unique_name) {
        output->setUniqueName(newValue->uniqueName());
      }
    }
  });
  uses_in_current_graph_.clear();
  assert(this->uses().empty());
}
1330 | |
// Constructs a node of the given kind, stamped with the graph's current
// stage, and registers it with the owning graph (which takes ownership).
// The node is not yet linked into the graph's node list.
inline Node::Node(Graph* graph_, NodeKind kind_)
    : kind_(kind_),
      graph_(graph_),
      stage_(graph_->new_node_stage_),
      has_name_(false),
      has_domain_(false),
      has_doc_string_(false) {
  graph_->all_nodes.emplace(this);
}
1340 | |
1341 | inline void Node::eraseOutput(size_t i) { |
1342 | ONNX_ASSERT(i < outputs_.size()); |
1343 | ONNX_ASSERT(outputs_[i]->uses().empty()); |
1344 | Value* n = outputs_[i]; |
1345 | outputs_.erase(outputs_.begin() + i); |
1346 | owningGraph()->freeValue(n); |
1347 | for (size_t j = i; j < outputs_.size(); j++) { |
1348 | outputs_[j]->offset_--; |
1349 | } |
1350 | } |
1351 | |
1352 | inline bool Node::isBefore(Node* n) { |
1353 | if (n == nullptr || this == n) { |
1354 | // Bail out early. |
1355 | return false; |
1356 | } |
1357 | // return true if node is Param (in initializers) |
1358 | if (kind_ == kParam) { |
1359 | return true; |
1360 | } |
1361 | // return false if target node is Param (in initializers) |
1362 | if (n->kind() == kParam) { |
1363 | return false; |
1364 | } |
1365 | ONNX_ASSERT(n->inGraphList()); |
1366 | for (Node* p = next(); p != *graph_->end(); p = p->next()) { |
1367 | if (p == n) { |
1368 | return true; |
1369 | } |
1370 | } |
1371 | return false; |
1372 | } |
1373 | |
1374 | inline void Node::destroy() { |
1375 | ONNX_ASSERT(inGraphList()); |
1376 | while (!outputs().empty()) |
1377 | eraseOutput(outputs().size() - 1); |
1378 | removeAllInputs(); |
1379 | removeFromList(); |
1380 | graph_->freeNode(this); |
1381 | } |
1382 | |
1383 | /************* All nodes not required to be defined before Graph **************/ |
1384 | |
1385 | inline graph_node_list_iterator Node::iterator() { |
1386 | return graph_node_list_iterator(this, 0); |
1387 | } |
1388 | inline graph_node_list_iterator Node::reverseIterator() { |
1389 | return iterator().reverse(); |
1390 | } |
1391 | inline const_graph_node_list_iterator Node::iterator() const { |
1392 | return const_graph_node_list_iterator(this, 0); |
1393 | } |
1394 | inline const_graph_node_list_iterator Node::reverseIterator() const { |
1395 | return iterator().reverse(); |
1396 | } |
1397 | |
// Returns the list of uses of this value; uses made by nodes
// inside subgraphs are included as well.
// This method is typically used to check whether it is
// safe to delete a Value.
1402 | inline const use_list Value::uses() const { |
1403 | use_list all_uses = uses_in_current_graph_; |
1404 | owningGraph()->forEachNode([this, &all_uses](const Node* node) { |
1405 | if (node->owningGraph() == this->owningGraph()) { |
1406 | // skip non-subgraph |
1407 | return; |
1408 | } |
1409 | if (node->kind() == kCaptured) { |
1410 | const Value* output = node->outputs()[0]; |
1411 | if (output->uniqueName() == this->uniqueName()) { |
1412 | const auto output_uses = output->uses(); |
1413 | all_uses.insert(all_uses.end(), output_uses.begin(), output_uses.end()); |
1414 | } |
1415 | } |
1416 | }); |
1417 | return all_uses; |
1418 | } |
1419 | |
1420 | } // namespace ONNX_NAMESPACE |
1421 | |