1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5// ATTENTION: The code in this file is highly EXPERIMENTAL.
6// Adventurous users should note that the APIs will probably change.
7
8#pragma once
9
#include <stdint.h>
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>
21
22#include "onnx/common/array_ref.h"
23#include "onnx/common/assertions.h"
24#include "onnx/common/common.h"
25#include "onnx/common/graph_node_list.h"
26#include "onnx/common/interned_strings.h"
27#include "onnx/common/tensor.h"
28#include "onnx/string_utils.h"
29
// Declares TypeName's copy constructor and copy assignment operator as deleted,
// making the type non-copyable. Intended for use inside the class body; the
// expansion has no trailing semicolon, so the use site supplies it, e.g.
//   ONNX_DISALLOW_COPY_AND_ASSIGN(Value);
#define ONNX_DISALLOW_COPY_AND_ASSIGN(TypeName) \
  TypeName(const TypeName&) = delete;           \
  TypeName& operator=(const TypeName&) = delete
33
34namespace ONNX_NAMESPACE {
35
// Graph represents one "function" of computation.
// It uses a simple ownership model where the graph owns all the nodes inside it.
// All references inside the graph are raw pointers.
// Destroying the Graph will invalidate any pointers to nodes in the graph.
struct Graph;

// Node is the base class of the IR graph. It represents one computation
// and its dependencies on a list of Values. The "prim-ops", so to speak.
struct Node;

// A Value represents an input or output of a node that is either a
// Tensor or an opaque Handle object, as determined by type().
// NOTE(review): the Value definition in this header exposes elemType() and
// sizes() but no type() accessor; confirm whether this comment is stale.
struct Value;
49
// Scope guard: runs the supplied callback when the guard is destroyed, unless
// release() was called first. Responsibility for the callback can be handed to
// another guard by moving; copying is disallowed.
class ResourceGuard final {
  std::function<void()> destructor_;
  bool released_;

 public:
  ResourceGuard(const ResourceGuard&) = delete;
  ResourceGuard& operator=(const ResourceGuard&) = delete;

  explicit ResourceGuard(std::function<void()> destructor) : destructor_(std::move(destructor)), released_(false) {}

  // Takes over 'other's callback. The moved-from guard is marked released so
  // its destructor does nothing. A defaulted move would leave it un-released
  // with an empty std::function, whose invocation throws
  // std::bad_function_call from a destructor (std::terminate).
  ResourceGuard(ResourceGuard&& other) : destructor_(std::move(other.destructor_)), released_(other.released_) {
    other.released_ = true;
  }

  // Runs this guard's pending callback (as destruction would), then takes over
  // 'other's callback and marks 'other' released. Self-assignment is a no-op.
  ResourceGuard& operator=(ResourceGuard&& other) {
    if (this != &other) {
      if (!released_ && destructor_) {
        destructor_();
      }
      destructor_ = std::move(other.destructor_);
      released_ = other.released_;
      other.released_ = true;
    }
    return *this;
  }

  ~ResourceGuard() {
    // An empty callback (e.g. constructed from an empty std::function) is
    // skipped instead of throwing out of the destructor.
    if (!released_ && destructor_) {
      destructor_();
    }
  }

  // Disarms the guard: the callback will not be run on destruction.
  void release() {
    released_ = true;
  }
};
69
// Extent of one tensor axis. The constructors record which of three forms it
// takes via the flags:
//   Dimension()          -> unknown (is_unknown, dim == -1)
//   Dimension(param)     -> symbolic, named by 'param' (dim == -1)
//   Dimension(int64_t)   -> concrete integer size (is_int)
// The converting constructors are intentionally implicit.
struct Dimension final {
  Dimension() = default;
  Dimension(std::string param) : is_unknown(false), param(std::move(param)) {} // NOLINT
  Dimension(int64_t dim) : is_unknown(false), is_int(true), dim(dim) {} // NOLINT

  bool is_unknown = true;
  bool is_int = false;
  int64_t dim = -1;
  std::string param;
};
80
// Tag identifying the stored type of an AttributeValue.
// toString(AttributeKind) indexes its name table directly with the enumerator
// value, so the values must stay dense, start at 0 and keep this order.
enum class AttributeKind : uint8_t {
  f = 0, // float
  fs = 1, // float list
  i = 2, // int
  is = 3, // int list
  s = 4, // string
  ss = 5, // string list
  t = 6, // tensor
  ts = 7, // tensor list
  g = 8, // subgraph
  gs = 9, // subgraph list
  tp = 10, // type proto
  tps = 11 // type proto list
};
97
98static inline const char* toString(AttributeKind kind) {
99 static constexpr const char* names[] = {"f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs", "tp", "tps"};
100 ONNX_ASSERT(size_t(kind) < sizeof(names) / sizeof(const char*));
101 return names[int(kind)];
102}
103
104struct AttributeValue {
105 explicit AttributeValue(Symbol name) : name(name) {}
106 using Ptr = std::unique_ptr<AttributeValue>;
107 Symbol name;
108 virtual AttributeKind kind() const = 0;
109 virtual Ptr clone() const = 0;
110 virtual ~AttributeValue() = default;
111};
112
113template <typename T, AttributeKind Kind>
114struct ScalarAttributeValue final : public AttributeValue {
115 using ConstructorType = const T&;
116 using ValueType = T;
117 ScalarAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(value_) {}
118 ValueType& value() {
119 return value_;
120 }
121 virtual Ptr clone() const override {
122 return Ptr(new ScalarAttributeValue(name, value_));
123 }
124 virtual AttributeKind kind() const override {
125 return Kind;
126 }
127
128 private:
129 ValueType value_;
130};
131
132template <typename T, AttributeKind Kind>
133struct VectorAttributeValue final : public AttributeValue {
134 using ConstructorType = const std::vector<T>&&;
135 using ValueType = std::vector<T>;
136 VectorAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(std::move(value_)) {}
137 ValueType& value() {
138 return value_;
139 }
140 virtual AttributeKind kind() const override {
141 return Kind;
142 }
143 virtual std::unique_ptr<AttributeValue> clone() const override {
144 auto copy = value_;
145 return Ptr(new VectorAttributeValue(name, std::move(copy)));
146 }
147
148 private:
149 ValueType value_;
150};
151
// Concrete attribute types, one per AttributeKind. Attributes<> builds its
// typed accessors (f/f_, fs/fs_, ...) from these aliases via the
// Kind##Attr naming pattern. Note that float attributes are stored as double.
using FloatAttr = ScalarAttributeValue<double, AttributeKind::f>;
using FloatsAttr = VectorAttributeValue<double, AttributeKind::fs>;
using IntAttr = ScalarAttributeValue<int64_t, AttributeKind::i>;
using IntsAttr = VectorAttributeValue<int64_t, AttributeKind::is>;
using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>;
using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>;
using TensorAttr = ScalarAttributeValue<Tensor, AttributeKind::t>;
using TensorsAttr = VectorAttributeValue<Tensor, AttributeKind::ts>;
// Subgraph attributes share ownership of the Graph through shared_ptr.
using GraphAttr = ScalarAttributeValue<std::shared_ptr<Graph>, AttributeKind::g>;
using GraphsAttr = VectorAttributeValue<std::shared_ptr<Graph>, AttributeKind::gs>;
using TypeProtoAttr = ScalarAttributeValue<TypeProto, AttributeKind::tp>;
using TypeProtosAttr = VectorAttributeValue<TypeProto, AttributeKind::tps>;
164
165// CRTP so that Node which inherits Attributes can be return for
166// method chaining e.g:
167// Node * n = g->create(kSelect)->set_i(kOffset,3)->set_f(kValue,3.5);
168// we return Derived* pointers because Nodes are normally held as pointers.
169template <typename Derived>
170struct Attributes {
171 Attributes() {}
172 void copyAttributes(const Attributes& rhs) {
173 values_.clear();
174 values_.reserve(rhs.values_.size());
175 for (auto& i : rhs.values_) {
176 values_.push_back(i->clone());
177 }
178 }
179 bool hasAttribute(Symbol name) const {
180 return find(name, false) != values_.end();
181 }
182 AttributeKind kindOf(Symbol name) const {
183 return (*find(name, true))->kind();
184 }
185 Derived* removeAttribute(Symbol name) {
186 values_.erase(find(name, true));
187 return This();
188 }
189 bool hasAttributes() const {
190 return !values_.empty();
191 }
192 // The names are returned in order, since name actually is the index.
193 std::vector<Symbol> attributeNames() const {
194 std::vector<Symbol> names;
195 names.reserve(values_.size());
196 for (auto& a : values_)
197 names.push_back(a->name);
198 return names;
199 }
200
201#define CREATE_ACCESSOR(Kind, method) \
202 Derived* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
203 return set<Kind##Attr>(name, std::forward<Kind##Attr::ConstructorType>(v)); \
204 } \
205 const Kind##Attr::ValueType& method(Symbol name) const { \
206 return get<Kind##Attr>(name); \
207 }
208 CREATE_ACCESSOR(Float, f)
209 CREATE_ACCESSOR(Floats, fs)
210 CREATE_ACCESSOR(String, s)
211 CREATE_ACCESSOR(Strings, ss)
212 CREATE_ACCESSOR(Int, i)
213 CREATE_ACCESSOR(Ints, is)
214 CREATE_ACCESSOR(Tensor, t)
215 CREATE_ACCESSOR(Tensors, ts)
216 CREATE_ACCESSOR(Graph, g)
217 CREATE_ACCESSOR(Graphs, gs)
218 CREATE_ACCESSOR(TypeProto, tp)
219 CREATE_ACCESSOR(TypeProtos, tps)
220
221#undef CREATE_ACCESSOR
222
223 private:
224 Derived* This() {
225 return static_cast<Derived*>(this);
226 }
227 template <typename T>
228 Derived* set(Symbol name, typename T::ConstructorType v) {
229 auto it = find(name, false);
230 auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
231 if (it == values_.end()) {
232 values_.push_back(std::move(nv));
233 } else {
234 *it = std::move(nv);
235 }
236 return This();
237 }
238 template <typename T>
239 typename T::ValueType& get(Symbol name) const {
240 auto it = find(name, true);
241 T* child = static_cast<T*>(it->get());
242 return child->value();
243 }
244 using AVPtr = AttributeValue::Ptr;
245 // NB: For determinism, we use a vector rather than a hash map. This does
246 // mean that lookups are O(n), so you shouldn't use Attributes to store
247 // a big pile of messages.
248 std::vector<AVPtr> values_;
249 using iterator = std::vector<AVPtr>::iterator;
250 iterator find(Symbol name, bool required) {
251 auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; });
252 ONNX_ASSERT(!required || it != values_.end());
253 return it;
254 }
255 using const_iterator = std::vector<AVPtr>::const_iterator;
256 const_iterator find(Symbol name, bool required) const {
257 auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; });
258 ONNX_ASSERTM(
259 !required || it != values_.end(),
260 "%s:%u: %s: required undefined attribute '%s'",
261 __FILE__,
262 __LINE__,
263 __func__,
264 name.toString());
265 return it;
266 }
267};
268
269// Each use is represented by this type, see Node::uses()
270// 'user' is the consumer of the value, offset is the index into
271// 'user's input this where the produces will be found.
272struct Use final {
273 Use(Node* user, size_t offset) : user(user), offset(offset) {}
274 Node* user;
275 size_t offset;
276};
277
278static inline bool operator==(const Use& a, const Use& b) {
279 return a.user == b.user && a.offset == b.offset;
280}
281
// The list types are intentionally simple, but we typedef them here so that,
// if we need to change them, refactoring will be easier.
using node_list = std::vector<Node*>;
using value_list = std::vector<Value*>;
using use_list = std::vector<Use>;
// A node's kind (operator type) is an interned Symbol.
using NodeKind = Symbol;
288
289struct Value final {
290 ONNX_DISALLOW_COPY_AND_ASSIGN(Value);
291 Value(Node* node_, size_t offset_);
292 Value(Value&&) = default;
293 Value& operator=(Value&&) = default;
294 ~Value() = default;
295
296 private:
297 friend struct Node;
298 friend struct Graph;
299 Node* node_;
300 size_t offset_;
301 size_t unique_ = 0; // unique id
302 size_t stage_ = 0; // 0-forward, 1-backward, 2-double-backward,...
303 use_list uses_in_current_graph_;
304 bool has_unique_name_;
305 std::string unique_name_;
306 int32_t elem_type_;
307 bool has_sizes_;
308 std::vector<Dimension> sizes_;
309
310 public:
311 Value* setElemType(int32_t elem_type) {
312 elem_type_ = elem_type;
313 return this;
314 }
315 int32_t elemType() const {
316 return elem_type_;
317 }
318 bool has_sizes() const {
319 return has_sizes_;
320 }
321 Value* setSizes(std::vector<Dimension> sizes) {
322 has_sizes_ = true;
323 sizes_ = std::move(sizes);
324 return this;
325 }
326 Value* wipeSizes() {
327 has_sizes_ = false;
328 sizes_ = std::vector<Dimension>();
329 return this;
330 }
331 const std::vector<Dimension>& sizes() const {
332 return sizes_;
333 }
334 size_t unique() const {
335 return unique_;
336 }
337 bool has_unique_name() const {
338 return has_unique_name_;
339 }
340 std::string uniqueName() const {
341 if (has_unique_name())
342 return unique_name_;
343 return ONNX_NAMESPACE::to_string(unique());
344 }
345 Value* setUniqueName(const std::string& name, bool rename_subgraph_captured_nodes = true);
346 Value* setStage(size_t s) {
347 stage_ = s;
348 return this;
349 }
350 size_t stage() const {
351 return stage_;
352 }
353 Node* node() {
354 return node_;
355 }
356 size_t offset() const {
357 return offset_;
358 }
359 const Node* node() const {
360 return node_;
361 }
362 Graph* owningGraph();
363 const Graph* owningGraph() const;
364 // TODO: make this more const correct
365 const use_list uses() const;
366
367 // Replaces all uses of this node with 'newValue'.
368 //
369 // Given: %3 = f(%1, %2)
370 // %4 = g(%3)
371 // %5 = h(%3, %3)
372 // Execute: %3.replaceAllUsesWith(%6)
373 // Result: %3 = f(%1, %2)
374 // %4 = g(%6)
375 // %5 = h(%6, %6)
376 void replaceAllUsesWith(Value* newValue);
377
378 Value* copyMetadata(Value* from) {
379 setElemType(from->elemType());
380 setSizes(from->sizes());
381 if (from->has_unique_name()) {
382 setUniqueName(from->uniqueName());
383 }
384 return this;
385 }
386};
387
// A single computation in a Graph. A Node has a kind, the input Values it
// consumes, the output Values it creates, attributes (through the
// Attributes<Node> CRTP base), an optional name/domain/doc string, and a
// position in its graph's topologically ordered node list.
struct Node : public Attributes<Node> {
  ONNX_DISALLOW_COPY_AND_ASSIGN(Node);
  friend struct Graph;
  friend struct Value;
  friend graph_node_list;
  friend const_graph_node_list;
  friend graph_node_list_iterator;
  friend const_graph_node_list_iterator;

 private:
  // Each node except Return/Param is associated with exactly one place in the
  // node list of graph_.
  // The list is circular and doubly linked; the Return node is used as the
  // sentinel for both its beginning and its end, so the list never contains
  // null pointers. next_in_graph[0] is the next pointer and next_in_graph[1]
  // is the prev pointer; using an array allows the same iterator class to
  // serve both forward and reverse node lists.
  // This list represents a topological sort.

  Node* next_in_graph[2] = {nullptr, nullptr};
  Node*& next() {
    return next_in_graph[kNextDirection];
  }
  Node*& prev() {
    return next_in_graph[kPrevDirection];
  }
  Node* const& next() const {
    return next_in_graph[kNextDirection];
  }
  Node* const& prev() const {
    return next_in_graph[kPrevDirection];
  }

  const NodeKind kind_;
  std::vector<Value*> inputs_;
  std::vector<Value*> outputs_;
  Graph* graph_;
  size_t stage_;
  bool has_name_;
  std::string name_;
  bool has_domain_;
  std::string domain_;
  bool has_doc_string_;
  std::string doc_string_;

 protected:
  Node(Graph* graph_, NodeKind kind_); // defined after graph

 public:
  bool has_name() const {
    return has_name_;
  }
  const std::string& name() const {
    return name_;
  }
  void setName(std::string name) {
    has_name_ = true;
    name_ = std::move(name);
  }
  bool has_domain() const {
    return has_domain_;
  }
  const std::string& domain() const {
    return domain_;
  }
  void setDomain(std::string domain) {
    has_domain_ = true;
    domain_ = std::move(domain);
  }
  bool has_doc_string() const {
    return has_doc_string_;
  }
  const std::string& docString() const {
    return doc_string_;
  }
  void setDocString(std::string doc_string) {
    has_doc_string_ = true;
    doc_string_ = std::move(doc_string);
  }
  NodeKind kind() const {
    return kind_;
  }
  Graph* owningGraph() {
    return graph_;
  }
  const Graph* owningGraph() const {
    return graph_;
  }
  size_t stage() const {
    return stage_;
  }
  Node* setStage(size_t s) {
    stage_ = s;
    return this;
  }
  // NB: This returns an ArrayRef; that means that it will
  // get invalidated if you resize inputs (e.g., using addInput)
  // We can't return a std::vector<Value*>& because there's no
  // way to soundly cast to std::vector<const Value*> (an insane
  // implementation of std::vector could make this representationally
  // different.)
  ArrayRef<Value*> inputs() {
    return inputs_;
  }
  ArrayRef<const Value*> inputs() const {
    // Vectors are not convertible in const-ness of elements, but
    // raw pointers are.
    return {inputs_.data(), inputs_.size()};
  }
  // NB: This returns an ArrayRef; that means that it will
  // get invalidated if you resize outputs (e.g., using addOutput)
  // We can't return a std::vector<Value*>& because there's no
  // way to soundly cast to std::vector<const Value*> (an insane
  // implementation of std::vector could make this representationally
  // different.)
  ArrayRef<Value*> outputs() {
    return outputs_;
  }
  ArrayRef<const Value*> outputs() const {
    // Vectors are not convertible in const-ness of elements, but
    // raw pointers are.
    return {outputs_.data(), outputs_.size()};
  }
  // True if any output of this node has at least one use.
  bool hasUses() const {
    for (auto o : outputs()) {
      if (!o->uses().empty())
        return true;
    }
    return false;
  }
  // Redirects every use of output i of this node to output i of 'n'.
  // Both nodes must have the same number of outputs (asserted).
  void replaceAllUsesWith(Node* n) {
    ONNX_ASSERT(outputs().size() == n->outputs().size());
    size_t nOutputs = outputs().size();
    for (size_t i = 0; i < nOutputs; i++) {
      outputs()[i]->replaceAllUsesWith(n->outputs()[i]);
    }
  }
  // lots of things like chunk have a single input or single output, so we have a
  // helper to make accessing it easier
  Value* input() {
    ONNX_ASSERT(inputs_.size() == 1);
    return inputs_.at(0);
  }
  Value* output() {
    ONNX_ASSERT(outputs_.size() == 1);
    return outputs_.at(0);
  }
  const Value* input() const {
    ONNX_ASSERT(inputs_.size() == 1);
    return inputs_.at(0);
  }
  // NOTE(review): unlike input() const, this const overload hands out a
  // mutable Value*; making it return const Value* would be an interface change.
  Value* output() const {
    ONNX_ASSERT(outputs_.size() == 1);
    return outputs_.at(0);
  }
  // Access a particular input. This is a checked index.
  Value* input(size_t i) {
    return inputs_.at(i);
  }
  const Value* input(size_t i) const {
    return inputs_.at(i);
  }

  // Graphs

  // Note [Topological invariant]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // We always maintain an up-to-date topological ordering of all nodes via
  // the next()/prev() links.  All transformations to graphs must preserve
  // this topological ordering: for example, it is only valid to 'addInput'
  // with an input which is topologically before the current node.
  //
  // Usually, it is obvious whether or not topological order is maintained;
  // for example, if you are adding nodes to the end of the topsort, it's
  // impossible for them to refer to inputs that are not in the topsort.
  // If it is not obvious, please comment accordingly.

  // Add value 'node' as an input to 'this' at the end of existing
  // arguments, recording the new use on it. Returns the added value for
  // ease of chaining. The value must belong to the same graph (asserted).
  //
  // Given:   %3 = f(%1, %2)
  // Execute: %3.addInput(%4)
  // Result:  %3 = f(%1, %2, %4)
  Value* addInput(Value* node) {
    ONNX_ASSERT(graph_ == node->owningGraph());
    node->uses_in_current_graph_.emplace_back(this, inputs_.size());
    inputs_.push_back(node);
    return node;
  }

  // Replace the input of 'this' at position 'i' with
  // 'newValue', returning the old value.
  //
  // Given:   %3 = f(%1, %2)
  // Execute: %3.replaceInput(1, %4)
  // Result:  %3 = f(%1, %4)
  Value* replaceInput(size_t i, Value* newValue) {
    ONNX_ASSERT(newValue->owningGraph() == graph_);
    Value* old = dropInput(i);
    inputs_[i] = newValue;
    newValue->uses_in_current_graph_.emplace_back(this, i);
    return old;
  }

  // Replace all occurrences of 'from' in the inputs of this
  // node with 'to'. Corresponds to llvm's replaceUsesOfWith.
  //
  // Given:   %3 = f(%1, %2, %1)
  // Execute: %3.replaceInputWith(%1, %4)
  // Result:  %3 = f(%4, %2, %4)
  void replaceInputWith(Value* from, Value* to) {
    ONNX_ASSERT(from->owningGraph() == graph_);
    ONNX_ASSERT(to->owningGraph() == graph_);
    size_t i = 0;
    for (auto input : inputs()) {
      if (input == from)
        replaceInput(i, to);
      i++;
    }
  }

  // Appends a new output Value produced by this node and returns it.
  Value* addOutput() {
    outputs_.push_back(new Value(this, outputs_.size()));
    return outputs_.back();
  }

  void eraseOutput(size_t i);

  // Insert unattached 'this' node before 'n' in the topological order.
  // Returns this (for chaining).
  //
  // Given:   %3 = f(%1, %2)
  //          %4 = g(%3)
  // and unattached: %5 = h(%1)
  // Execute: %5.insertBefore(%4)
  // Result:  %3 = f(%1, %2)
  //          %5 = h(%1)
  //          %4 = g(%3)
  Node* insertBefore(Node* n) {
    ONNX_ASSERT(n->inGraphList());
    insertAfter(n->prev());
    return this;
  }

  // Insert unattached 'this' node after 'n' in the topological order.
  // Returns this (for chaining).
  //
  // Given: %3 = f(%1, %2)
  //        %4 = g(%3)
  // and unattached: %5 = h(%1)
  // Execute: %5.insertAfter(%4)
  // Result:  %3 = f(%1, %2)
  //          %4 = g(%3)
  //          %5 = h(%1)
  Node* insertAfter(Node* n) {
    ONNX_ASSERT(!inGraphList() && n->inGraphList());
    Node* next = n->next();
    n->next() = this;
    this->prev() = n;
    this->next() = next;
    next->prev() = this;
    return this;
  }

  // Move 'this' (already in the graph) after 'n' in the topological order.
  //
  // Given: %2 = f(%1)
  //        %3 = g(%1)
  // Execute: %2.moveAfter(%3)
  // Result: %3 = g(%1)
  //         %2 = f(%1)
  //
  void moveAfter(Node* n) {
    removeFromList();
    insertAfter(n);
  }

  // Move 'this' (already in the graph) before 'n' in the topological order.
  //
  // Given: %2 = f(%1)
  //        %3 = g(%1)
  // Execute: %3.moveBefore(%2)
  // Result: %3 = g(%1)
  //         %2 = f(%1)
  void moveBefore(Node* n) {
    removeFromList();
    insertBefore(n);
  }

  // Remove the input at 'i' from this node.
  //
  // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
  // removeInput.
  //
  // Given: %3 = f(%1, %2)
  // Execute: %3.removeInput(1)
  // Result: %3 = f(%1)
  void removeInput(size_t i) {
    dropInput(i);
    // everything after this input shifts left,
    // so we need to update their use offsets to match
    for (size_t j = i + 1; j < inputs_.size(); j++) {
      auto it = findUseForInput(j);
      it->offset--;
    }
    inputs_.erase(inputs_.begin() + i);
  }

  // Remove all inputs from a node.
  //
  // Given: %3 = f(%1, %2)
  // Execute: %3.removeAllInputs()
  // Result: %3 = f()
  void removeAllInputs() {
    for (size_t i = 0; i < inputs().size(); ++i)
      dropInput(i);
    inputs_.clear();
  }

  // Check whether this node is before node n in the graph.
  bool isBefore(Node* n);

  // iterators of the node list starting at this node
  // useful for resuming a search starting at this node
  graph_node_list_iterator iterator();
  graph_node_list_iterator reverseIterator();
  const_graph_node_list_iterator iterator() const;
  const_graph_node_list_iterator reverseIterator() const;

  // Remove 'this' from the instruction list and deallocate it.
  //
  // Invariant: no outputs of 'this' may have any uses.
  //
  // Given: %2 = f(%1)
  //        %3 = g(%1)
  // Execute: %2.destroy()
  // Result: %3 = g(%1)
  void destroy();

  // Dynamically cast this node to the subclass indicated by the
  // template variable, returning nullptr if the cast is invalid.
  // The check compares T::Kind with kind(); no RTTI is used.
  //
  // Example usage: if(auto s = n.cast<Select>()) { ... }
  //
  // TODO: Make this const correct
  template <typename T>
  T* cast() {
    if (T::Kind == kind())
      return static_cast<T*>(this);
    return nullptr;
  }
  // Like cast<T>(), but asserts that the kind matches instead of returning
  // nullptr.
  template <typename T>
  T* expect() {
    ONNX_ASSERTM(T::Kind == kind(), "expected a %s but found a %s", T::Kind.toString(), kind().toString());
    return static_cast<T*>(this);
  }

  virtual ~Node() = default;

 private:
  // Lookup iterator in use list of _input i_ that corresponds to its use of _this_
  use_list::iterator findUseForInput(size_t i) {
    auto& input_uses = inputs_[i]->uses_in_current_graph_;
    // O(N) on the use list, but unless we get nodes with +100 uses
    // vector traversal still is probably faster than linked list
    auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i));
    ONNX_ASSERT(use_it != input_uses.end());
    return use_it;
  }

  // remove the use of input i, this sets input i to nullptr, but
  // is only used internally to Node before setting it to a new value
  // or erasing the entry from the list.
  Value* dropInput(size_t i) {
    ONNX_ASSERT(i < inputs_.size());
    auto input_node = inputs_[i];
    auto use_it = findUseForInput(i);
    input_node->uses_in_current_graph_.erase(use_it);
    inputs_[i] = nullptr;
    return input_node;
  }

  // True if this node is linked into a node list; asserts that prev() is
  // null whenever next() is.
  bool inGraphList() const {
    ONNX_ASSERT(next() != nullptr || prev() == nullptr);
    return next() != nullptr;
  }
  // Unlinks this node from the node list without deallocating it.
  void removeFromList() {
    ONNX_ASSERT(inGraphList());
    Node* next = this->next();
    Node* prev = this->prev();
    prev->next() = next;
    next->prev() = prev;
    this->next() = nullptr;
    this->prev() = nullptr;
  }

 protected:
  // subclasses must override
  // this function is used by createClone to initialize a new version
  // of a node in another graph. It should allocate a new instance of the same
  // concrete type as 'this', but in graph 'g' which might be different
  // than graph_
  virtual Node* allocNewInstance(Graph* g) {
    return new Node(g, kind());
  }
  // create a copy of all properties of Node s into this.
  // subclasses should extend if they have additional information to copy.
  // 'this' will be allocated with s->allocNewInstance(g) so it should have
  // the same concrete type as 's'
  //
  // NB: This does NOT clone stages.  You're expected to set the stage correctly
  // if you are going to preserve it.
  virtual void cloneFrom(Node* s) {
    copyAttributes(*s);
  }
};
803
804// A class with the same properties as OperatorSetIdProto, but without protobuf
805// overhead, resulting in a simpler and more readable workflow.
806class OpSetID final {
807 private:
808 std::string domain_;
809 int64_t version_;
810
811 public:
812 explicit OpSetID(const OperatorSetIdProto& proto) : domain_(proto.domain()), version_(proto.version()) {}
813
814 // Default Domain Constructor
815 explicit OpSetID(const int64_t version) : domain_(""), version_(version) {}
816
817 explicit OpSetID(const std::string& domain, int64_t version) : domain_(domain), version_(version) {}
818
819 // target must be in the form "<domain>&<version>"
820 std::string toString() const {
821 return domain_ + "$" + ONNX_NAMESPACE::to_string(version_);
822 }
823
824 // target must be in the form "<domain>&<version>"
825 static OpSetID fromString(const std::string& target) {
826 ONNX_TRY {
827 std::string new_domain = target.substr(0, target.find("$"));
828 int new_version = ONNX_NAMESPACE::stoi(target.substr(target.find("$") + 1, target.length()).c_str());
829 return OpSetID(new_domain, new_version);
830 }
831 ONNX_CATCH(const std::runtime_error& e) {
832 ONNX_HANDLE_EXCEPTION([&]() { ONNX_ASSERTM(false, "Error in fromString: %s", e.what()); });
833 }
834
835 // The control will never reach here.
836 // In the default build where exceptions are turned on in case of any error
837 // the control will enter catch block where an exception will be thrown again.
838 // In case of "no exception build" the code aborts at the site of first exception.
839 // Adding this to appease the warning "control may reach end of non-void function"
840 // as the mac build fails when ONNX_WERROR==ON
841 return OpSetID("", 0);
842 }
843
844 const std::string& domain() const {
845 return domain_;
846 }
847
848 int64_t version() const {
849 return version_;
850 }
851
852 void incrementVersion(int64_t step) {
853 version_ += step;
854 }
855
856 void setVersion(int64_t newVal) {
857 version_ = newVal;
858 }
859};
860
861struct Graph final {
862 ONNX_DISALLOW_COPY_AND_ASSIGN(Graph);
863 friend struct Node;
864 friend struct Value;
865
866 private:
867 // only used to keep track of allocated nodes
868 // actual representation of Graph is done with
869 // inputs, outputs, nodes
870
871 std::unordered_set<const Node*> all_nodes;
872 std::unordered_set<const Value*> all_values;
873 size_t next_unique_;
874
875 size_t new_node_stage_;
876
877 // holds outputs in a way that can be reflected
878 // as a Use object
879 // also used as the beginning/end of the circular node list to avoid
880 // having corner cases where the list is empty.
881 Node* const output_;
882 Node* const input_;
883 // Create an independent node list for those initializers do not exist in input
884 Node* const initializer_node_;
885
886 std::vector<Tensor> initializers_;
887 std::vector<std::string> initializer_names_;
888
889 bool has_name_;
890 std::string name_;
891 bool has_doc_string_;
892 std::string doc_string_;
893
894 std::vector<OpSetID> opset_versions_;
895
896 bool isNameUnique(const std::string& name) const {
897 if (std::find(initializer_names_.cbegin(), initializer_names_.cend(), name) != initializer_names_.cend()) {
898 return false;
899 }
900 const auto f = [&name](const Value* v) { return v->uniqueName() == name; };
901 for (const Node* node : all_nodes) {
902 for (const auto& attr : node->attributeNames()) {
903 if (node->kindOf(attr) == AttributeKind::g) {
904 const auto& subgraph = node->g(attr);
905 if (!subgraph->isNameUnique(name)) {
906 return false;
907 }
908 } else if (node->kindOf(attr) == AttributeKind::gs) {
909 for (const auto& subgraph : node->gs(attr)) {
910 if (!subgraph->isNameUnique(name)) {
911 return false;
912 }
913 }
914 }
915 }
916 const auto found_in = std::find_if(node->inputs().begin(), node->inputs().end(), f);
917 if (found_in != node->inputs().end()) {
918 return false;
919 }
920 const auto found_out = std::find_if(node->outputs().begin(), node->outputs().end(), f);
921 if (found_out != node->outputs().end()) {
922 return false;
923 }
924 }
925 return true;
926 }
927
928 public:
  // Creates an empty graph holding three bookkeeping nodes: the Return node
  // (output_, passed through initOutput), the Param node whose outputs are the
  // graph inputs (input_), and a separate Param node whose outputs are
  // initializer values (initializer_node_).
  // NOTE(review): member initializers run in declaration order, and create()
  // runs here while later-declared members (initializers_, initializer_names_,
  // name/doc fields) are not yet constructed; keep declarations and these calls
  // in sync.
  Graph()
      : next_unique_(0),
        new_node_stage_(0),
        output_(initOutput(create(kReturn, 0))),
        input_(create(kParam, 0)),
        initializer_node_(create(kParam, 0)),
        has_name_(false),
        has_doc_string_(false) {}
937
938 bool has_doc_string() const {
939 return has_doc_string_;
940 }
941 const std::string& docString() {
942 return doc_string_;
943 }
944 void setDocString(std::string doc_string) {
945 has_doc_string_ = true;
946 doc_string_ = std::move(doc_string);
947 }
948
949 void addInitializer(Tensor& initializer) {
950 if (initializer.name().empty()) {
951 initializer.setName(ONNX_NAMESPACE::to_string(getNextUnique()));
952 }
953 initializers_.push_back(initializer);
954 initializer_names_.push_back(initializer.name());
955 }
956
957 // For IR >= 4, initializer is not required to exist in input
958 // Add initializer into initializer node list and return its Value
959 Value* addInitializerAndCreateValue(Tensor& initializer) {
960 addInitializer(initializer);
961 auto* init_value = initializer_node_->addOutput();
962 std::vector<Dimension> dim_sizes{initializer.sizes().cbegin(), initializer.sizes().cend()};
963 init_value->setUniqueName(initializer.name());
964 init_value->setSizes(dim_sizes);
965 init_value->setElemType(initializer.elem_type());
966 return init_value;
967 }
968
969 void eraseInitializer(const std::string& name) {
970 initializers_.erase(
971 std::remove_if(
972 initializers_.begin(),
973 initializers_.end(),
974 [&name](Tensor& initializer) { return initializer.name() == name; }),
975 initializers_.end());
976 initializer_names_.erase(
977 std::remove(initializer_names_.begin(), initializer_names_.end(), name), initializer_names_.end());
978 for (size_t i = 0; i < initializer_node_->outputs().size(); i++) {
979 if (initializer_node_->outputs()[i]->uniqueName() == name) {
980 initializer_node_->eraseOutput(i);
981 break;
982 }
983 }
984 }
  // Drops all initializer tensors and their recorded names.
  // NOTE(review): unlike eraseInitializer(), this does not erase the
  // corresponding outputs of initializer_node_, so values created by
  // addInitializerAndCreateValue() remain (and is_constant_initializer() still
  // reports true for them). Confirm whether callers rely on that.
  void clearInitializers() {
    initializers_.clear();
    initializer_names_.clear();
  }
989 const std::vector<Tensor>& initializers() const {
990 return initializers_;
991 }
992 const std::vector<std::string>& initializer_names() const {
993 return initializer_names_;
994 }
995 std::vector<Tensor>::const_iterator getInitializer(const std::string& name) const {
996 for (auto it = initializers_.cbegin(); it != initializers_.cend(); ++it) {
997 if (name == it->name()) {
998 return it;
999 }
1000 }
1001 return initializers_.end();
1002 }
1003 bool is_constant_initializer(const Value* value) const {
1004 return value->node() == initializer_node_;
1005 }
1006 ArrayRef<Value*> inputs() {
1007 return input_->outputs();
1008 }
1009 ArrayRef<const Value*> inputs() const {
1010 const auto& inputs = input_->outputs();
1011 return {inputs.data(), inputs.size()};
1012 }
1013 ArrayRef<Value*> outputs() {
1014 return output_->inputs();
1015 }
1016 ArrayRef<const Value*> outputs() const {
1017 return static_cast<const Node*>(output_)->inputs();
1018 }
1019 graph_node_list nodes() {
1020 return graph_node_list(output_, kNextDirection);
1021 }
1022 const_graph_node_list nodes() const {
1023 return const_graph_node_list(output_, kNextDirection);
1024 }
1025
1026 std::vector<OpSetID>& opset_versions_mutable() {
1027 return opset_versions_;
1028 }
1029
1030 size_t getNextUnique() {
1031 std::string next_unique_name = ONNX_NAMESPACE::to_string(++next_unique_);
1032 while (!isNameUnique(next_unique_name)) {
1033 next_unique_name = ONNX_NAMESPACE::to_string(++next_unique_);
1034 }
1035 return next_unique_;
1036 }
1037
  // Iterators over this graph's node list, all obtained from nodes().
  //
  // Calling begin()/end()/rbegin()/rend() on the temporary returned by
  // nodes() is fine: graph_node_list is non-owning, so it does not matter
  // that the list object dies immediately after the call.
  graph_node_list_iterator begin() {
    return nodes().begin();
  }
  const_graph_node_list_iterator begin() const {
    return nodes().begin();
  }
  graph_node_list_iterator end() {
    return nodes().end();
  }
  const_graph_node_list_iterator end() const {
    return nodes().end();
  }
  // Reverse traversal of the same node list.
  graph_node_list_iterator rbegin() {
    return nodes().rbegin();
  }
  const_graph_node_list_iterator rbegin() const {
    return nodes().rbegin();
  }
  graph_node_list_iterator rend() {
    return nodes().rend();
  }
  const_graph_node_list_iterator rend() const {
    return nodes().rend();
  }
1065 Node* return_node() {
1066 return output_;
1067 }
1068 const Node* return_node() const {
1069 return output_;
1070 }
1071
1072 Value* addInput() {
1073 return input_->addOutput();
1074 }
1075 void eraseInput(size_t i) {
1076 input_->eraseOutput(i);
1077 }
1078 void advanceStage() {
1079 new_node_stage_++;
1080 }
1081 void setStage(size_t new_stage) {
1082 new_node_stage_ = new_stage;
1083 }
1084 size_t stage() const {
1085 return new_node_stage_;
1086 }
1087 ResourceGuard setStageTemporary(size_t s) {
1088 auto prev_stage = new_node_stage_;
1089 new_node_stage_ = s;
1090 return ResourceGuard([prev_stage, this]() { this->new_node_stage_ = prev_stage; });
1091 }
1092
1093 size_t registerOutput(Value* n) {
1094 output_->addInput(n);
1095 return outputs().size() - 1;
1096 }
1097
1098 Node* create(NodeKind kind, size_t num_outputs = 1) {
1099 // NB: Node constructor adds node to all_nodes
1100 auto n = new Node(this, kind);
1101 for (size_t i = 0; i < num_outputs; i++)
1102 n->addOutput();
1103 return n;
1104 }
1105
1106 Node* create(NodeKind kind, ArrayRef<Value*> inputs, size_t num_outputs = 1) {
1107 auto n = create(kind, num_outputs);
1108 for (auto i : inputs)
1109 n->addInput(i);
1110 return n;
1111 }
1112
1113 Node* appendNode(Node* n) {
1114 ONNX_ASSERT(n->graph_ == this && !n->inGraphList());
1115 n->insertBefore(output_);
1116 return n;
1117 }
1118
1119 Node* prependNode(Node* n) {
1120 ONNX_ASSERT(n->graph_ == this && !n->inGraphList());
1121 n->insertAfter(output_);
1122 return n;
1123 }
1124
1125 // Adds to graph initializer list, initializer names list, and as a graph input
1126 // Also syncs the initializer name, tensor name, and value name
1127 // Create an initializer whose value is stored in input
1128 Value* addInitializerAndInput(const Tensor& initializer, const std::string& name) {
1129 Tensor initializerCopy = initializer;
1130 std::vector<Dimension> dim_sizes{initializerCopy.sizes().cbegin(), initializerCopy.sizes().cend()};
1131 Value* new_init = addInput();
1132 initializerCopy.setName(name);
1133 new_init->setUniqueName(name);
1134 new_init->setSizes(dim_sizes);
1135 new_init->setElemType(initializerCopy.elem_type());
1136 addInitializer(initializerCopy);
1137 return new_init;
1138 }
1139
1140 Value* addInitializerAndInput(const Tensor& initializer) {
1141 return addInitializerAndInput(initializer, ONNX_NAMESPACE::to_string(getNextUnique()));
1142 }
1143
1144 // Erases from graph initializer list, initializer names list, and as a graph input
1145 // Must have no uses
1146 void eraseInitializerAndInput(Value* v) {
1147 eraseInitializer(v->uniqueName());
1148 if (v->node() == input_) {
1149 eraseInput(v->offset());
1150 }
1151 }
1152
1153 ~Graph() {
1154 for (const Node* n : all_nodes)
1155 delete n;
1156 for (const Value* v : all_values)
1157 delete v;
1158 }
1159
1160 std::string toString() const {
1161 std::ostringstream oss;
1162 oss << *this;
1163 return oss.str();
1164 }
1165
1166 bool has_name() const {
1167 return has_name_;
1168 }
1169
1170 const std::string& name() const {
1171 return name_;
1172 }
1173
1174 void setName(std::string name) {
1175 has_name_ = true;
1176 name_ = std::move(name);
1177 }
1178
  // Stream printer for a whole graph (defined out of line); toString() uses it.
  friend std::ostream& operator<<(std::ostream& out, const Graph& g);
1180
1181 void forSelfAndEachSubGraph(const std::function<void(Graph*)>& fn) {
1182 fn(this);
1183
1184 for (const Node* node : all_nodes) {
1185 for (const auto& attr : node->attributeNames()) {
1186 if (node->kindOf(attr) == AttributeKind::g) {
1187 std::shared_ptr<Graph> subgraph = node->g(attr);
1188 subgraph->forSelfAndEachSubGraph(fn);
1189 } else if (node->kindOf(attr) == AttributeKind::gs) {
1190 for (const auto& subgraph : node->gs(attr)) {
1191 subgraph->forSelfAndEachSubGraph(fn);
1192 }
1193 }
1194 }
1195 }
1196 }
1197
1198 void forSelfAndEachSubGraph(const std::function<void(const Graph*)>& fn) const {
1199 std::function<void(Graph*)> tmp_fn = [fn](Graph* graph) { fn(graph); };
1200 const_cast<Graph*>(this)->forSelfAndEachSubGraph(tmp_fn);
1201 }
1202
1203 void forEachNode(const std::function<void(Node*)>& fn) {
1204 forSelfAndEachSubGraph([fn](Graph* graph) {
1205 for (Node* node : graph->nodes()) {
1206 fn(node);
1207 }
1208 });
1209 }
1210
1211 void forEachNode(const std::function<void(const Node*)>& fn) const {
1212 std::function<void(Node*)> tmp_fn = [fn](Node* node) { fn(node); };
1213 const_cast<Graph*>(this)->forEachNode(tmp_fn);
1214 }
1215
1216 private:
1217 // should only be called in the constructor
1218 Node* initOutput(Node* p) {
1219 p->next() = p;
1220 p->prev() = p;
1221 p->setStage(std::numeric_limits<size_t>::max());
1222 return p;
1223 }
1224
1225 void freeNode(Node* n) {
1226 auto it = all_nodes.find(n);
1227 ONNX_ASSERT(it != all_nodes.end());
1228 delete *it;
1229 all_nodes.erase(it);
1230 }
1231 void freeValue(Value* v) {
1232 auto it = all_values.find(v);
1233 ONNX_ASSERT(it != all_values.end());
1234 delete *it;
1235 all_values.erase(it);
1236 }
1237};
1238
// Constructs a Value attached to `node_` at position `offset_` (presumably an
// output slot of that node; the caller is not visible here). The Value takes
// a fresh counter from the graph's getNextUnique(), copies the graph's current
// creation stage, starts with no explicit unique name, an undefined element
// type and no sizes, and registers itself in all_values so the graph deletes
// it on destruction.
// NOTE: the parameters shadow the members of the same name; inside the
// initializer expressions and the body, `node_` refers to the parameter.
inline Value::Value(Node* node_, size_t offset_)
    : node_(node_),
      offset_(offset_),
      unique_(node_->graph_->getNextUnique()),
      stage_(node_->graph_->new_node_stage_),
      has_unique_name_(false),
      elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED),
      has_sizes_(false) {
  node_->graph_->all_values.emplace(this);
}
1249
1250inline Graph* Value::owningGraph() {
1251 return node()->owningGraph();
1252}
1253
1254inline const Graph* Value::owningGraph() const {
1255 return node()->owningGraph();
1256}
1257
1258// `captured` nodes in subgraph determines which value it captures
1259// by storing the value's unique name, so old unique names in `captured` nodes
1260// should also be updated.
1261// Initializer names are also storaged in graph.initializer_names_, it should be
1262// updated too.
1263inline Value* Value::setUniqueName(const std::string& name, bool update_related_names) {
1264 if (has_unique_name() && update_related_names) {
1265 auto* graph = owningGraph();
1266 auto old_name = unique_name_;
1267 for (size_t i = 0; i < owningGraph()->initializer_names_.size(); i++) {
1268 auto& initializer_name = owningGraph()->initializer_names_[i];
1269 if (initializer_name == old_name) {
1270 initializer_name = name;
1271 owningGraph()->initializers_[i].setName(name);
1272 }
1273 }
1274 graph->forEachNode([this, &name, &old_name](Node* node) {
1275 if (node->owningGraph() == this->owningGraph()) {
1276 // skip non-subgraph
1277 return;
1278 }
1279 if (node->kind() == kCaptured) {
1280 Value* output = node->output();
1281 if (output->uniqueName() == old_name) {
1282 output->setUniqueName(name, false);
1283 }
1284 }
1285 });
1286 }
1287 unique_name_ = name;
1288 has_unique_name_ = true;
1289 return this;
1290}
1291
1292inline void Value::replaceAllUsesWith(Value* newValue) {
1293 auto* graph = owningGraph();
1294 ONNX_ASSERT(graph == newValue->owningGraph());
1295 // propagate sizes and elem type
1296 if (this->has_sizes()) {
1297 newValue->setSizes(this->sizes());
1298 }
1299 if (this->elemType() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
1300 newValue->setElemType(this->elemType());
1301 }
1302 const auto unique_name = this->uniqueName();
1303 // We do not want the optimization to change the graph output name
1304 if (std::find(graph->outputs().rbegin(), graph->outputs().rend(), this) != graph->outputs().rend()) {
1305 newValue->setUniqueName(unique_name);
1306 // The "unique" semantic of unique_name should be kept or uses()
1307 // will return an incorrect result when the value is used in subgraph
1308 this->setUniqueName(ONNX_NAMESPACE::to_string(graph->getNextUnique()), false);
1309 }
1310 newValue->uses_in_current_graph_.reserve(this->uses_in_current_graph_.size());
1311 for (auto u : uses_in_current_graph_) {
1312 u.user->inputs_[u.offset] = newValue;
1313 newValue->uses_in_current_graph_.push_back(u);
1314 }
1315 graph->forEachNode([this, &newValue, &unique_name](Node* node) {
1316 if (node->owningGraph() == this->owningGraph()) {
1317 // skip non-subgraph
1318 return;
1319 }
1320 if (node->kind() == kCaptured) {
1321 Value* output = node->output();
1322 if (output->uniqueName() == unique_name) {
1323 output->setUniqueName(newValue->uniqueName());
1324 }
1325 }
1326 });
1327 uses_in_current_graph_.clear();
1328 assert(this->uses().empty());
1329}
1330
// Constructs a node of kind `kind_` owned by `graph_`. It copies the graph's
// current creation stage, starts without a name, domain or doc string, and
// registers itself in all_nodes so the graph deletes it on destruction.
// NOTE: the parameters shadow the members of the same name; inside the
// initializer expressions and the body, `graph_` refers to the parameter.
inline Node::Node(Graph* graph_, NodeKind kind_)
    : kind_(kind_),
      graph_(graph_),
      stage_(graph_->new_node_stage_),
      has_name_(false),
      has_domain_(false),
      has_doc_string_(false) {
  graph_->all_nodes.emplace(this);
}
1340
1341inline void Node::eraseOutput(size_t i) {
1342 ONNX_ASSERT(i < outputs_.size());
1343 ONNX_ASSERT(outputs_[i]->uses().empty());
1344 Value* n = outputs_[i];
1345 outputs_.erase(outputs_.begin() + i);
1346 owningGraph()->freeValue(n);
1347 for (size_t j = i; j < outputs_.size(); j++) {
1348 outputs_[j]->offset_--;
1349 }
1350}
1351
1352inline bool Node::isBefore(Node* n) {
1353 if (n == nullptr || this == n) {
1354 // Bail out early.
1355 return false;
1356 }
1357 // return true if node is Param (in initializers)
1358 if (kind_ == kParam) {
1359 return true;
1360 }
1361 // return false if target node is Param (in initializers)
1362 if (n->kind() == kParam) {
1363 return false;
1364 }
1365 ONNX_ASSERT(n->inGraphList());
1366 for (Node* p = next(); p != *graph_->end(); p = p->next()) {
1367 if (p == n) {
1368 return true;
1369 }
1370 }
1371 return false;
1372}
1373
1374inline void Node::destroy() {
1375 ONNX_ASSERT(inGraphList());
1376 while (!outputs().empty())
1377 eraseOutput(outputs().size() - 1);
1378 removeAllInputs();
1379 removeFromList();
1380 graph_->freeNode(this);
1381}
1382
/************* Definitions below do not need to precede Graph **************/
1384
1385inline graph_node_list_iterator Node::iterator() {
1386 return graph_node_list_iterator(this, 0);
1387}
1388inline graph_node_list_iterator Node::reverseIterator() {
1389 return iterator().reverse();
1390}
1391inline const_graph_node_list_iterator Node::iterator() const {
1392 return const_graph_node_list_iterator(this, 0);
1393}
1394inline const_graph_node_list_iterator Node::reverseIterator() const {
1395 return iterator().reverse();
1396}
1397
1398// Returns a list about which nodes are using this value,
1399// nodes in subgraph are also included.
1400// This method is usually used to check whether it is
1401// safe to delete a Value.
1402inline const use_list Value::uses() const {
1403 use_list all_uses = uses_in_current_graph_;
1404 owningGraph()->forEachNode([this, &all_uses](const Node* node) {
1405 if (node->owningGraph() == this->owningGraph()) {
1406 // skip non-subgraph
1407 return;
1408 }
1409 if (node->kind() == kCaptured) {
1410 const Value* output = node->outputs()[0];
1411 if (output->uniqueName() == this->uniqueName()) {
1412 const auto output_uses = output->uses();
1413 all_uses.insert(all_uses.end(), output_uses.begin(), output_uses.end());
1414 }
1415 }
1416 });
1417 return all_uses;
1418}
1419
1420} // namespace ONNX_NAMESPACE
1421