1#pragma once
2
3#include <torch/csrc/jit/ir/attributes.h>
4#include <torch/csrc/jit/ir/graph_node_list.h>
5#include <torch/csrc/jit/ir/named_value.h>
6#include <torch/csrc/jit/ir/scope.h>
7#include <torch/csrc/jit/runtime/operator.h>
8
9#include <torch/csrc/Export.h>
10#include <torch/csrc/utils/python_stub.h>
11#include <torch/csrc/utils/schema_info.h>
12
13#include <ATen/Utils.h>
14#include <ATen/core/Tensor.h>
15#include <ATen/core/dynamic_type.h>
16#include <ATen/core/enum_type.h>
17#include <ATen/core/functional.h>
18#include <ATen/core/interned_strings.h>
19#include <ATen/core/ivalue.h>
20#include <ATen/core/jit_type.h>
21#include <c10/util/ArrayRef.h>
22#include <c10/util/Exception.h>
23#include <c10/util/Optional.h>
24
25#include <functional>
26#include <iostream>
27#include <unordered_set>
28#include <vector>
29
30// Forward declare, the real meat is in python_ir.cpp
31template <class T>
32class THPPointer;
33using THPObjectPtr = THPPointer<PyObject>;
34using pyobj_list = std::vector<THPObjectPtr>;
35
36namespace torch {
37namespace jit {
38namespace utils {
39TORCH_API std::string getNodesModuleHierarchy(const Node& n);
40} // namespace utils
41class AliasDb;
42
43using ::c10::Argument;
44using ::c10::FunctionSchema;
45using ::c10::Symbol;
46
47using ::c10::ivalue::Shared;
48
49using ::c10::IValue;
50using ::c10::ivalue::Future;
51
52using ::c10::ivalue::ConstantString;
53
54#define C10_USING(T) using ::c10::T;
55C10_FORALL_TYPES(C10_USING)
56#undef C10_USING
57
58#define C10_USING(T) using ::c10::T##Ptr;
59C10_FORALL_TYPES(C10_USING)
60#undef C10_USING
61
62using ::c10::Type;
63using ::c10::TypeEnv;
64using ::c10::TypePtr;
65
66using ::c10::getTypePtr;
67using ::c10::MatchTypeReturn;
68using ::c10::TypeKind;
69
70using ::c10::fmap;
71
72namespace prim {
73using namespace ::c10::prim;
74}
75namespace attr {
76using namespace ::c10::attr;
77}
78namespace aten {
79using namespace ::c10::aten;
80}
81namespace cuda {
82#if !defined(USE_ROCM)
83using namespace ::c10::cuda;
84#endif
85} // namespace cuda
86
87struct Function;
88struct GraphFunction;
89struct MatchedSchema;
90
91// A Graph represents one "function" of computation.
92// It uses a simple ownership model where the graph owns all the nodes inside
93// it. All references inside the graph are raw pointers. Destroying the Graph
94// will invalidate any pointers to nodes in the graph.
95struct Graph;
96
97// Node is the base class of the IR graph. It represents one computation
98// and dependencies on a list of Values. The "prim-ops", so to speak.
99struct Node;
100
101// A Value represents an input or output to node that is either a
102// Tensor or an opaque Handle object, as determined by type().
103struct Value;
104
105TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);
106TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n);
107
108// A list of nodes, with inputs and outputs
109struct Block;
110
111// Each use is represented by this type, see 'Node::uses()'
112// 'user' is the consumer of the value, 'offset' is the index into
113// 'user's input this where the producers will be found.
114struct Use {
115 Use(Node* user, size_t offset) : user(user), offset(offset) {}
116 Node* user;
117 size_t offset;
118
119 bool operator==(const Use& b) {
120 return user == b.user && offset == b.offset;
121 }
122};
123
124// Note [User node does not uniquely identify use]
125// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
126// A while back, we wrote some code manipulating uses that looked like this:
127//
128// for (auto& use : used_val->uses_) {
129// if (use.user == this_node) {
130// use.offset += 1;
131// break;
132// }
133// }
134//
135// This code is trying to find a particular use (our node's use) to update it.
136// However, it's wrong: there may be *multiple* uses of a value %x in a node,
137// as might be the case in this IR:
138//
139// %y = Add %x %x
140//
141// In this case, there are two uses of %x whose user is the node 'Add %x %x'.
142// So, "use induced by this node" is not a well-formed concept.
143//
144// If you are looking for "use induced by an input", it's best to use
145// findUseForInput() to get it.
146
147// the list types are intentionally simple, but we type-def
148// them here so if we need to change them, refactoring will be easier
149using node_list = std::vector<Node*>;
150using value_list = std::vector<Value*>;
151using use_list = std::vector<Use>;
152template <typename T>
153using ArrayRef = at::ArrayRef<T>;
154using NodeKind = Symbol;
155using topo_position_t = int64_t;
156using ValueSet = std::unordered_set<const Value*>;
157
158struct OperatorSet;
159template <typename T>
160struct OperatorMap;
161
// Handle that lets the Python object for a Node/Value/Block be invalidated
// safely when the underlying C++ object is deleted: the owner calls clear()
// from its destructor, which runs the optional `clear_cb` on the old pointer
// and then nulls out `elem`.
// Like much of the graph, this is not safe for different threads to use on
// the same graph.
template <typename T>
struct Wrap {
  explicit Wrap(T* p) : elem(p) {}

  void clear() {
    if (clear_cb != nullptr) {
      clear_cb(elem);
    }
    elem = nullptr;
  }

  T* elem;
  void (*clear_cb)(void*) = nullptr;
};
178
179struct Value {
180 AT_DISALLOW_COPY_AND_ASSIGN(Value);
181 Value(Node* node_, size_t offset_);
182
183 private:
184 friend struct Node;
185 friend struct Graph;
186 Node* node_;
187 size_t offset_;
188 size_t unique_ = 0; // unique id
189 use_list uses_;
190 std::string unique_name_;
191 TypePtr type_;
192 // a managing wrapper for Python to allow invalidation
193 std::shared_ptr<Wrap<Value>> wrap_;
194
195 public:
196 Value* setType(TypePtr type);
197 TORCH_API void inferTypeFrom(const at::Tensor& output);
198 TORCH_API void inferTypeFrom(
199 const c10::intrusive_ptr<c10::ivalue::Object>& output);
200 const TypePtr& type() const {
201 AT_ASSERT(type_ != nullptr);
202 return type_;
203 }
204 bool requires_grad() const {
205 return type()->requires_grad();
206 }
207 bool isCompleteTensor() const {
208 if (auto pt = type()->cast<TensorType>()) {
209 return pt->isComplete();
210 }
211 return false;
212 }
213 TORCH_API bool mustBeNone() const;
214 TORCH_API bool mustNotBeNone() const;
215 size_t unique() const {
216 return unique_;
217 }
218 bool hasDebugName() const {
219 return !unique_name_.empty();
220 }
221 static bool isValidName(const std::string& name);
222 TORCH_API Value* setDebugName(const std::string& name);
223 std::string debugName() const {
224 if (hasDebugName()) {
225 return unique_name_;
226 }
227 return c10::to_string(unique());
228 }
229 TORCH_API std::string debugNameBase() const;
230 Node* node() {
231 return node_;
232 }
233 size_t offset() const {
234 return offset_;
235 }
236 void setOffset(size_t offset) {
237 offset_ = offset;
238 }
239 const Node* node() const {
240 return node_;
241 }
242
243 /**
244 * @warning NEVER pass raw pointer of smart pointer managed Graph to Python.
245 * Check #87343 for details.
246 */
247 Graph* owningGraph();
248 const Graph* owningGraph() const;
249 // TODO: make this more const correct
250 const use_list& uses() const {
251 return uses_;
252 }
253
254 bool hasUses() const {
255 return !uses().empty();
256 }
257
258 TORCH_API void replaceFirstUseWith(Value* newValue);
259
260 // Replaces all uses of this value with 'newValue'.
261 //
262 // Given: %3 = f(%1, %2)
263 // %4 = g(%3)
264 // %5 = h(%3, %3)
265 // Execute: %3.replaceAllUsesWith(%6)
266 // Result: %3 = f(%1, %2)
267 // %4 = g(%6)
268 // %5 = h(%6, %6)
269 TORCH_API void replaceAllUsesWith(Value* newValue);
270
271 // Replaces all uses of this value with 'newValue' after 'node'.
272 // Given: %3 = f(%1, %2)
273 // %4 = g(%3)
274 // %5 = inplace_(%3)
275 // %6 = h(%3, %3)
276 // Execute: %3.replaceAllUsesAfterNodeWith(%5.node(), %5)
277 // Result: %3 = f(%1, %2)
278 // %4 = g(%3)
279 // %5 = inplace_(%3)
280 // %6 = h(%5, %5)
281 // XXX: does not check scoping legality, consider using
282 // replaceAllUsesDominatedByNodeWith
283 TORCH_API void replaceAllUsesAfterNodeWith(const Node* node, Value* newValue);
284
285 // Replaces all uses of this value with 'newValue' that are dominated by
286 // 'node'. Given:
287 // x = op(...).
288 // if cond:
289 // z = foo(..)
290 // bar(x)
291 // else:
292 // print(x)
293 // x.replaceAllUsesDominatedByNodeWith(foo, z) would replace bar(x)
294 // but not print(x) because print is not dominated by foo.
295 // replaceAllUsesAfterNode does not check domination, so in this example
296 // it would produce invalid IR.
297 TORCH_API void replaceAllUsesDominatedByNodeWith(
298 const Node* node,
299 Value* newValue);
300
301 TORCH_API Value* copyMetadata(Value* from);
302
303 TORCH_API std::shared_ptr<Wrap<Value>> wrap() {
304 if (!wrap_) {
305 wrap_ = std::make_shared<Wrap<Value>>(this);
306 }
307 return wrap_;
308 }
309
310 virtual ~Value() {
311 if (wrap_) {
312 wrap_->clear();
313 }
314 }
315};
316
317struct TORCH_API Node {
318 AT_DISALLOW_COPY_AND_ASSIGN(Node);
319 friend struct Graph;
320 friend struct Block;
321 friend struct Value;
322 friend graph_node_list;
323 friend const_graph_node_list;
324 friend graph_node_list_iterator;
325 friend const_graph_node_list_iterator;
326
327 private:
328 const NodeKind kind_;
329 std::vector<Value*> inputs_;
330 std::vector<Value*> outputs_;
331 // subblocks
332 std::vector<Block*> blocks_;
333 Graph* graph_;
334 Block* owning_block_;
335 c10::optional<SourceRange> source_range_;
336 ScopePtr scope_;
337 c10::optional<InlinedCallStackPtr> callstack_;
338 // Assumes FunctionSchemas are persistent, so we don't manage their lifetime.
339 // This field is effective a cache that's populated on attribute lookups and
340 // invalidated every time we perform an operation that could potentially
341 // change the schema. note: mutable because schema_ is effectively a cache
342 mutable const Operator* op_;
343 topo_position_t topo_position_ = 0;
344 // a managing wrapper for Python to allow invalidation
345 std::shared_ptr<Wrap<Node>> wrap_;
346 // Stores the full schema name, if the operator is historic
347 // When the operator is deprecated or the name of the operator
348 // is changed, we need to rely on this name
349 // to retrieve old schemas to successfully apply upgraders
350 // for this operator.
351 c10::optional<std::string> historic_schema_name_ = c10::nullopt;
352
353 protected:
354 Node(Graph* graph_, NodeKind kind_); // defined after graph
355 public:
356 // Each Node but Return/Param Nodes are associated with exactly one
357 // place in the Node list of the Graph. The Graph itself is a circular
358 // doubly-linked list. The Return Node is used as the sentinel for the
359 // "beginning"/"end" of the list. This means that you can tell when
360 // you've traversed the entire list without means worrying about null
361 // pointers. `next_in_graph[0]` is the pointer to the next Node, while
362 // `next_in_graph[1]` is the pointer to the previous Node. The
363 // linked list is implemented as an array to allow the same iterator
364 // class for forward and reversed Node lists. Taken together, this
365 // list also represents a topological sort of the Nodes in the Graph.
366 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-non-private-member-variables-in-classes,modernize-avoid-c-arrays)
367 Node* next_in_graph[2] = {nullptr, nullptr};
368
369 std::shared_ptr<Wrap<Node>> wrap() {
370 if (!wrap_) {
371 wrap_ = std::make_shared<Wrap<Node>>(this);
372 }
373 return wrap_;
374 }
375
376 const c10::optional<std::string> getHistoricSchemaName() {
377 return historic_schema_name_;
378 }
379
380 void setHistoricSchemaName(const std::string& name) {
381 historic_schema_name_ = name;
382 }
383
384 Node*& next() {
385 return next_in_graph[kNextDirection];
386 }
387 Node*& prev() {
388 return next_in_graph[kPrevDirection];
389 }
390 Node* const& next() const {
391 return next_in_graph[kNextDirection];
392 }
393 Node* const& prev() const {
394 return next_in_graph[kPrevDirection];
395 }
396
397 NodeKind kind() const {
398 return kind_;
399 }
400 Node* setSourceRange(SourceRange r) {
401 source_range_ = std::move(r);
402 return this;
403 }
404 SourceRange sourceRange() const;
405
406 /**
407 * @warning NEVER pass raw pointer of smart pointer managed Graph to Python.
408 * Check #87343 for details.
409 */
410 Graph* owningGraph() {
411 return graph_;
412 }
413 const Graph* owningGraph() const {
414 return graph_;
415 }
416 Block* owningBlock() {
417 return owning_block_;
418 }
419 const Block* owningBlock() const {
420 return owning_block_;
421 }
422 ScopePtr scope() {
423 return scope_;
424 }
425 void setScope(ScopePtr scope) {
426 scope_ = std::move(scope);
427 }
428 std::string scopeName() const {
429 if (!scope_) {
430 return "";
431 }
432 return scope_->namesFromRoot();
433 }
434
435 // Copies the source range, scope and callstack from another node.
436 Node* copyMetadata(Node* from) {
437 this->setSourceRange(from->sourceRange());
438 this->setScope(from->scope());
439 if (auto cs = from->callstack()) {
440 this->setCallStack(*cs);
441 }
442 return this;
443 }
444
445 c10::optional<InlinedCallStackPtr> callstack() const {
446 return callstack_;
447 }
448 void setCallStack(InlinedCallStackPtr cs) {
449 callstack_ = cs;
450 }
451
452 // NB: This returns an ArrayRef; that means that it will
453 // get invalidated if you resize inputs (e.g., using addInput)
454 // We can't return a std::vector<Node*>& because there's no
455 // way to soundly cast to std::vector<const Node*> (an insane
456 // implementation of std::vector could make this representationally
457 // different.)
458 at::ArrayRef<Value*> inputs() {
459 return inputs_;
460 }
461 at::ArrayRef<const Value*> inputs() const {
462 // Vectors are not convertible in const-ness of elements, but
463 // raw pointers are.
464 return {inputs_.data(), inputs_.size()};
465 }
466 // NB: This returns an ArrayRef; that means that it will
467 // get invalidated if you resize inputs (e.g., using addInput)
468 // We can't return a std::vector<Node*>& because there's no
469 // way to soundly cast to std::vector<const Node*> (an insane
470 // implementation of std::vector could make this representationally
471 // different.)
472 at::ArrayRef<Value*> outputs() {
473 return outputs_;
474 }
475 at::ArrayRef<const Value*> outputs() const {
476 // Vectors are not convertible in const-ness of elements, but
477 // raw pointers are.
478 return {outputs_.data(), outputs_.size()};
479 }
480 Value* output(size_t i) const {
481 return outputs_.at(i);
482 }
483 bool hasUses() const {
484 for (auto o : outputs()) {
485 if (!o->uses().empty()) {
486 return true;
487 }
488 }
489 return false;
490 }
491
492 void replaceAllUsesWith(Node* n);
493
494 // replaces `this` with a new node with the same inputs and outputs
495 // but a new node symbol. does not destroy `this`
496 Node* replaceWithNewSymbol(Symbol new_symbol);
497
498 // Checks if this node is dominated by `dominator` which means that
499 // `dominator` will always be executed before `this` and `dominator`
500 // is in scope of `this.
501 bool isDominatedBy(const Node* dominator) const;
502
503 // lots of things like chunk have a single input or single output, so we have
504 // a helper to make accessing it easier
505 Value* input() {
506 AT_ASSERT(inputs_.size() == 1);
507 return inputs_.at(0);
508 }
509 Value* output() {
510 AT_ASSERT(outputs_.size() == 1);
511 return outputs_.at(0);
512 }
513 const Value* output() const {
514 AT_ASSERT(outputs_.size() == 1);
515 return outputs_.at(0);
516 }
517 const Value* input() const {
518 AT_ASSERT(inputs_.size() == 1);
519 return inputs_.at(0);
520 }
521 // Access a particular input. This is a checked index.
522 Value* input(size_t i) const {
523 return inputs_.at(i);
524 }
525
526 bool hasNamedInput(const std::string& unqualName) const;
527 Value* namedInput(const std::string& unqualName) const;
528 Value* namedInput(Symbol name) const;
529
530 c10::optional<IValue> get(Symbol name) const;
531
532 template <typename T>
533 c10::optional<T> get(Symbol name) const {
534 if (auto v = get(name)) {
535 return v->template to<T>();
536 }
537 return c10::nullopt;
538 }
539
540 // Returns true if the value of input name is statically known
541 bool is_constant(Symbol name) const {
542 return static_cast<bool>(get(name));
543 }
544 bool mustBeNone() const;
545
546 bool isNondeterministic() const;
547 bool hasSideEffects() const;
548
549 // instructions lowered by the interpreter and not run in the optimized graph
550 bool notExecutedOp() const {
551 return kind_ == prim::Constant || kind_ == prim::profile ||
552 kind_ == prim::profile_ivalue;
553 }
554
555 // Graphs
556
557 // Note [Topological invariant]
558 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
559 // We always maintain an up-to-date topological ordering of all nodes via
560 // the next()/prev() links. All transformations to graphs must preserve
561 // this topological ordering: for example, it is only valid to 'addInput'
562 // with an input which is topologically before the current node.
563 //
564 // Usually, it is obvious whether or not topological order is maintained;
565 // for example, if you are adding nodes to the end of the topsort, it's
566 // impossible for them to refer to inputs that are not in the topsort.
567 // If it is not obvious, please comment accordingly.
568
569 // Add 'node' as an input to 'this' at the end of existing
570 // arguments. Returns the added node for ease of chaining.
571 //
572 // Given: %3 = f(%1, %2)
573 // Execute: %3.addInput(%4)
574 // Result: %3 = f(%1, %2, %4)
575 Value* addInput(Value* value);
576
577 // Add 'value' as an input to 'this' at the specified position in the
578 // arguments. Returns the added value for ease of chaining.
579 Value* insertInput(size_t i, Value* value);
580
581 // Replace the input of 'this' at position 'i' with
582 // 'newValue', returning the old node.
583 //
584 // Given: %3 = f(%1, %2)
585 // Execute: %3.replaceInput(1, %4)
586 // Result: %3 = f(%1, %4)
587 Value* replaceInput(size_t i, Value* newValue);
588
589 // Replace all occurrences of 'from' in the inputs of this
590 // node with 'to'. Corresponds to llvm's replaceUsesOfWith.
591 //
592 // Given: %3 = f(%1, %2, %1)
593 // Execute: %3.replaceInputWith(%1, %4)
594 // Result: %3 = f(%4, %2, %4)
595 void replaceInputWith(Value* from, Value* to);
596
597 Value* addOutput();
598
599 Value* insertOutput(size_t i);
600
601 void eraseOutput(size_t i);
602
603 Block* addBlock();
604 void eraseBlock(size_t i);
605
606 // Each Node can have a list of subblocks. These are used to define structured
607 // nested control flow operators such as If and Loop.
608 // The meaning of a block is specific to the kind of node it is in, but
609 // all blocks share these semantics:
610 // * Nested lexical scoping: If a node 'Parent' has a subblock which contains
611 // a node 'Child', Child can use any value that was in scope for the Parent
612 // node in addition to any values defined before 'Child' in the subblock.
613 // * The list of inputs to the block are in scope for the duration of the
614 // block
615 // * the outputs of the Parent node are not in scope for the subblocks
616 // Typically the inputs to a block that represents control flow act as
617 // as the equivalents phi-nodes in standard SSA form,
618 // defining a new Value to represent any term that has multiple
619 // definitions depending on how control flowed. Outputs of the node containing
620 // control flow serve a similiar purpose defining new values for variables
621 // that would have different definitions depending on which way control
622 // flowed.
623
624 at::ArrayRef<Block*> blocks() {
625 return blocks_;
626 }
627 at::ArrayRef<const Block*> blocks() const {
628 // Vectors are not convertible in const-ness of elements, but
629 // raw pointers are.
630 return {blocks_.data(), blocks_.size()};
631 }
632
633 // Is 'this' before 'n' in the topological order?
634 bool isBefore(const Node* n) const;
635
636 // Is 'this' after 'n' in the topological order?
637 bool isAfter(const Node* n) const;
638
639 // Insert unattached 'this' node before 'n' in the topological order.
640 // Returns this (for chaining).
641 //
642 // Given: %3 = f(%1, %2)
643 // %4 = g(%3)
644 // and unattached: %5 = h(%1)
645 // Execute: %5.insertBefore(%4)
646 // Result: %3 = f(%1, %2)
647 // %5 = h(%1)
648 // %4 = g(%3)
649 Node* insertBefore(Node* n);
650
651 // Insert unattached 'this' node after 'n' in the topological order.
652 // Returns this (for chaining).
653 //
654 // Given: %3 = f(%1, %2)
655 // %4 = g(%3)
656 // and unattached: %5 = h(%1)
657 // Execute: %5.insertAfter(%4)
658 // Result: %3 = f(%1, %2)
659 // %4 = g(%3)
660 // %5 = h(%1)
661 Node* insertAfter(Node* n);
662
663 // Move 'this' (already in the graph) after 'n' in the topological order.
664 //
665 // NOTE: Does not check that value dependencies are preserved, see
666 // AliasDb::moveAfterTopologicallyValid
667 //
668 // Given: %2 = f(%1)
669 // %3 = g(%1)
670 // Execute: %2.moveAfter(%3)
671 // Result: %3 = g(%1)
672 // %2 = f(%1)
673 //
674 void moveAfter(Node* n);
675
676 // Move a node 'n' (already in the graph) before 'this' in the topological
677 // order.
678 //
679 // NOTE: Does not check that value dependencies are preserved, see
680 // AliasDb::moveBeforeTopologicallyValid
681 //
682 // Given: %2 = f(%1)
683 // %3 = g(%1)
684 // Execute: %3.moveBefore(%2)
685 // Result: %3 = g(%1)
686 // %2 = f(%1)
687 void moveBefore(Node* n);
688
689 // Remove the input at 'i' from this node.
690 //
691 // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
692 // removeInput.
693 //
694 // Given: %3 = f(%1, %2)
695 // Execute: %3.removeInput(1)
696 // Result: %3 = f(%1)
697 void removeInput(size_t i);
698
699 // Remove all inputs from a node.
700 //
701 // Given: %3 = f(%1, %2)
702 // Execute: %3.removeAllInputs()
703 // Result: %3 = f()
704 void removeAllInputs();
705
706 // Remove all outputs from a node.
707 //
708 // Given: %1, %2 = f()
709 // Execute:removeAllInputs()
710 // Result: = f()
711 void removeAllOutputs();
712
713 // Rearrange the ordering of inputs or outputs of a node
714 // Given: %3 = f(%1, %2)
715 // Execute: %3.permuteInputs({1, 0})
716 // Result: %3 = f(%2, %1)
717 // Each index must appear exactly once
718 void permuteInputs(const std::vector<size_t>& new_inputs);
719 void permuteOutputs(const std::vector<size_t>& new_inputs);
720
721 // iterators of the node list starting at this node
722 // useful for resuming a search starting at this node
723 inline graph_node_list_iterator iterator() {
724 return {this, 0};
725 }
726 inline graph_node_list_iterator reverseIterator() {
727 return iterator().reverse();
728 }
729 inline const_graph_node_list_iterator iterator() const {
730 return {this, 0};
731 }
732 inline const_graph_node_list_iterator reverseIterator() const {
733 return iterator().reverse();
734 }
735
736 // Remove 'this' from the instruction list and deallocate it.
737 //
738 // Invariant: no outputs of 'this' may have any uses.
739 //
740 // Given: %2 = f(%1)
741 // %3 = g(%1)
742 // Execute: %2.destroy()
743 // Result: %3 = g(%1)
744 void destroy();
745
746 // Dynamically cast this node to the subclass indicated by the
747 // template variable, returning nullptr if the cast is invalid..
748 //
749 // Example usage: if(auto s = n.cast<Select>()) { ... }
750 template <typename T>
751 T* cast() {
752 if (T::Kind == kind()) {
753 return static_cast<T*>(this);
754 }
755 return nullptr;
756 }
757 template <typename T>
758 const T* cast() const {
759 if (T::Kind == kind()) {
760 return static_cast<const T*>(this);
761 }
762 return nullptr;
763 }
764
765 template <typename T>
766 T* expect() {
767 TORCH_CHECK(
768 T::Kind == kind(),
769 "expected a ",
770 T::Kind.toDisplayString(),
771 " but found a ",
772 kind().toDisplayString());
773 return static_cast<T*>(this);
774 }
775
776 bool matches(const FunctionSchema& schema) const;
777
778 // XXX: this function is meant to be used with string literals only!
779 bool matches(
780 const char* signature_literal,
781 at::ArrayRef<Symbol> const_inputs = {}) const;
782
783 bool isMemberOf(const OperatorSet& os) const;
784 template <typename T>
785 bool isMemberOf(const OperatorMap<T>& om) const {
786 auto it = om.map.find(kind());
787 if (it == om.map.end()) {
788 return false;
789 }
790 for (auto& op : it->second) {
791 if (matches(op.first->schema())) {
792 return true;
793 }
794 }
795 return false;
796 }
797
798 const FunctionSchema& schema() const;
799 const FunctionSchema* maybeSchema() const;
800 const Operator& getOperator() const;
801 Operation getOperation() const;
802
803 const Operator* maybeOperator() const;
804
805 void dump() const;
806
807 std::ostream& print(
808 std::ostream& out,
809 size_t level,
810 std::vector<const Node*>* groups,
811 bool print_source_locations = true,
812 bool print_attributes = true,
813 bool print_scopes = true,
814 bool print_body = true) const;
815
816 virtual ~Node() {
817 if (wrap_) {
818 wrap_->clear();
819 }
820 }
821
822 // Methods for accessing attributes
823 Node* copyAttributes(const Node& rhs) {
824 values_.clear();
825 for (const AVPtr& i : rhs.values_) {
826 values_.push_back(i->clone());
827 }
828 return this;
829 }
830 bool hasAttribute(Symbol name) const {
831 AT_ASSERT(name.is_attr());
832 return findAttr(name, false) != values_.end();
833 }
834 bool hasAttributeS(const std::string& name) const {
835 return hasAttribute(Symbol::attr(name));
836 }
837 AttributeKind kindOf(Symbol name) const {
838 AT_ASSERT(name.is_attr());
839 return (*findAttr(name, true))->kind();
840 }
841 AttributeKind kindOfS(const std::string& name) const {
842 return kindOf(Symbol::attr(name));
843 }
844 Node* removeAttribute(Symbol name) {
845 AT_ASSERT(name.is_attr());
846 values_.erase(findAttr(name, true));
847 return this;
848 }
849 Node* removeAttributeS(const std::string& name) {
850 return removeAttribute(Symbol::attr(name));
851 }
852 bool hasAttributes() const {
853 return !values_.empty();
854 }
855 size_t numAttributes() const {
856 return values_.size();
857 }
858 // The names are returned in order, since name actually is the index.
859 std::vector<Symbol> attributeNames() const {
860 std::vector<Symbol> names;
861 names.reserve(values_.size());
862 for (const AVPtr& a : values_) {
863 names.push_back(a->name);
864 }
865 return names;
866 }
867 std::vector<const char*> attributeNamesS() const {
868 std::vector<const char*> names;
869 names.reserve(values_.size());
870 for (const AVPtr& a : values_) {
871 names.push_back(a->name.toUnqualString());
872 }
873 return names;
874 }
875
876#define CREATE_ACCESSOR(Kind, method) \
877 Node* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
878 return setAttr<Kind##Attr>( \
879 name, std::forward<Kind##Attr::ConstructorType>(v)); \
880 } \
881 const Kind##Attr::ValueType& method(Symbol name) const { \
882 return getAttr<Kind##Attr>(name); \
883 }
884
885 CREATE_ACCESSOR(Float, f)
886 CREATE_ACCESSOR(Complex, c)
887 CREATE_ACCESSOR(Floats, fs)
888 CREATE_ACCESSOR(ComplexVals, cs)
889 CREATE_ACCESSOR(String, s)
890 CREATE_ACCESSOR(Strings, ss)
891 CREATE_ACCESSOR(Int, i)
892 CREATE_ACCESSOR(Ints, is)
893 CREATE_ACCESSOR(Graph, g)
894 CREATE_ACCESSOR(Graphs, gs)
895 CREATE_ACCESSOR(Type, ty)
896 CREATE_ACCESSOR(Types, tys)
897 CREATE_ACCESSOR(IValue, ival)
898
899#undef CREATE_ACCESSOR
900
901 // Our Graphs are not very const-correct, so we need to allow returning
902 // non-const references too
903 GraphAttr::ValueType& g(Symbol name) {
904 return getAttr<GraphAttr>(name);
905 }
906
907 // does not use CREATE_ACCESSOR because we need additional asserts
908 Node* t_(Symbol name, TensorAttr::ConstructorType v) {
909 return setAttr<TensorAttr>(
910 name, std::forward<TensorAttr::ConstructorType>(v));
911 }
912 const TensorAttr::ValueType& t(Symbol name) const {
913 return getAttr<TensorAttr>(name);
914 }
915
916 Node* ts_(Symbol name, TensorsAttr::ConstructorType v) {
917 return setAttr<TensorsAttr>(
918 name, std::forward<TensorsAttr::ConstructorType>(v));
919 }
920 const TensorsAttr::ValueType& ts(Symbol name) const {
921 return getAttr<TensorsAttr>(name);
922 }
923
924 Block* findCommonAncestorBlockWith(Node* n);
925
926 size_t blocksFromGraphBlock();
927
928 private:
929 void printAttrValue(std::ostream& out, const Symbol& name) const;
930 void printAttributes(std::ostream& out, bool ignore_subgraph) const;
931
932 template <typename T>
933 Node* setAttr(Symbol name, typename T::ConstructorType v) {
934 AT_ASSERT(name.is_attr());
935 auto it = findAttr(name, false);
936 auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
937 // NOLINTNEXTLINE(bugprone-branch-clone)
938 if (it == values_.end()) {
939 values_.push_back(std::move(nv));
940 } else {
941 *it = std::move(nv);
942 }
943 return this;
944 }
945 template <typename T>
946 typename T::ValueType& getAttr(Symbol name) const {
947 AT_ASSERT(name.is_attr());
948 auto it = findAttr(name, true);
949 auto* child = dynamic_cast<T*>(it->get());
950 if (child == nullptr) {
951 throw IRAttributeError(name, true);
952 }
953 return child->value();
954 }
955 using AVPtr = AttributeValue::Ptr;
956 // NB: For determinism, we use a vector rather than a hash map. This does
957 // mean that lookups are O(n), so you shouldn't use Attributes to store
958 // a big pile of messages.
959 std::vector<AVPtr> values_;
960 std::vector<AVPtr>::iterator findAttr(Symbol name, bool required) {
961 AT_ASSERT(name.is_attr());
962 auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
963 return v->name == name;
964 });
965 if (required && it == values_.end()) {
966 throw IRAttributeError(name, false);
967 }
968 AT_ASSERT(!required || it != values_.end());
969 return it;
970 }
971 std::vector<AVPtr>::const_iterator findAttr(Symbol name, bool required)
972 const {
973 AT_ASSERT(name.is_attr());
974 auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
975 return v->name == name;
976 });
977 if (required && it == values_.end()) {
978 throw IRAttributeError(name, false);
979 }
980 AT_ASSERT(!required || it != values_.end());
981 return it;
982 }
983
984 enum class MoveSide { BEFORE, AFTER };
985 bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
986
987 std::pair<Value*, const Argument&> findInput(Symbol name);
988 // Lookup iterator in use list of _input i_ that corresponds to its use of
989 // _this_
990 use_list::iterator findUseForInput(size_t i);
991
992 // remove the use of input i, this sets input i to nullptr, but
993 // is only used internally to Node before setting it to a new value
994 // or erasing the entry from the list.
995 Value* dropInput(size_t i);
996
997 bool inBlockList() const {
998 if (next() == nullptr) {
999 AT_ASSERT(prev() == nullptr);
1000 }
1001 return next() != nullptr;
1002 }
1003
1004 void removeFromList();
1005 void lint() const;
1006
1007 void assignTopoPosition();
1008
1009 protected:
1010 // subclasses must override
1011 // this function is used by createClone to initialize a new version
1012 // of a node in another graph. It should allocate a new instance of the same
1013 // concrete type as 'this', but in graph 'g' which might be different
1014 // than graph_
1015 virtual Node* allocNewInstance(Graph* g) {
1016 return new Node(g, kind());
1017 }
1018 // create a copy of all properties of Node s into this.
1019 // subclasses should extend if they have additional information to copy.
1020 // 'this' will be allocated with s->allocNewInstance(g) so it should have
1021 // the same concrete type as 's'
1022 virtual void cloneFrom(Node* s);
1023};
1024
// A Block is a list of nodes that belongs to a Graph and, unless it is the
// graph's root block, to an owning Node (owningNode() is nullptr for the root).
//
// The block's inputs are stored as the outputs of the sentinel node `input_`
// (exposed as param_node()), and its outputs as the inputs of the sentinel
// node `output_` (exposed as return_node()). `output_` also anchors the
// circular node list so that the list never has an "empty" corner case.
struct Block {
  friend struct Node;
  friend struct Graph;

  AT_DISALLOW_COPY_AND_ASSIGN(Block);
  TORCH_API Block(Graph* graph_, Node* node_);

  // Block inputs, i.e. the outputs of param_node().
  at::ArrayRef<Value*> inputs() {
    return input_->outputs();
  }
  at::ArrayRef<const Value*> inputs() const {
    const auto& inputs = input_->outputs();
    return {inputs.data(), inputs.size()};
  }
  // Block outputs, i.e. the inputs of return_node().
  at::ArrayRef<Value*> outputs() {
    return output_->inputs();
  }
  at::ArrayRef<const Value*> outputs() const {
    return static_cast<const Node*>(output_)->inputs();
  }
  // Node list of this block, anchored at param_node() and walked in the
  // kNextDirection.
  graph_node_list nodes() {
    return {input_, kNextDirection};
  }
  const_graph_node_list nodes() const {
    return {input_, kNextDirection};
  }
  // Sentinel node whose inputs are the block outputs.
  Node* return_node() {
    return output_;
  }
  const Node* return_node() const {
    return output_;
  }
  // Sentinel node whose outputs are the block inputs.
  Node* param_node() {
    return input_;
  }
  const Node* param_node() const {
    return input_;
  }
  /**
   * @warning NEVER pass raw pointer of smart pointer managed Graph to Python.
   * Check #87343 for details.
   */
  Graph* owningGraph() {
    return graph_;
  }
  const Graph* owningGraph() const {
    return graph_;
  }
  // The node that contains this block, or nullptr for a graph's root block.
  Node* owningNode() {
    return owning_node_;
  }
  const Node* owningNode() const {
    return owning_node_;
  }

  // Appends a new block input (a new output of param_node()), gives it the
  // debug name `name`, and returns it.
  Value* addInput(const std::string& name = "") {
    Value* v = input_->addOutput();
    v->setDebugName(name);
    return v;
  }
  // Like addInput(), but inserts the new input at position `i`.
  Value* insertInput(size_t i, const std::string& name = "") {
    Value* v = input_->insertOutput(i);
    v->setDebugName(name);
    return v;
  }
  void eraseInput(size_t i) {
    input_->eraseOutput(i);
  }
  void removeAllInputs() {
    input_->removeAllOutputs();
  }
  // Appends `v` as a block output and returns its index among the outputs.
  size_t registerOutput(Value* v) {
    output_->addInput(v);
    return outputs().size() - 1;
  }
  // Inserts `n` as the block output at position `i` and returns `i`.
  size_t insertOutput(size_t i, Value* n) {
    output_->insertInput(i, n);
    return i;
  }
  void eraseOutput(size_t i) {
    output_->removeInput(i);
  }
  void removeAllOutputs() {
    output_->removeAllInputs();
  }

  void replaceOutput(size_t i, Value* n) {
    output_->replaceInput(i, n);
  }
  // Reorders block outputs (the inputs of return_node()).
  void permuteOutputs(const std::vector<size_t>& new_inputs) {
    output_->permuteInputs(new_inputs);
  }
  // Reorders block inputs (the outputs of param_node()).
  void permuteInputs(const std::vector<size_t>& new_inputs) {
    input_->permuteOutputs(new_inputs);
  }

  // Places `n` at the end of the block (just before return_node()).
  // `n` must belong to this block's graph and not already be in a block list.
  Node* appendNode(Node* n) {
    AT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
    n->insertBefore(output_);
    return n;
  }
  // Places `n` at the start of the block (just after param_node()).
  // Same preconditions as appendNode().
  Node* prependNode(Node* n) {
    AT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
    n->insertAfter(input_);
    return n;
  }

  // clone all inputs, nodes, and outputs from src and append them
  // to the inputs, nodes, and outputs of this block
  // value_map is used whenever a node in src references a free variable
  // in src to look up its corresponding value
  TORCH_API void cloneFrom(Block* src, std::function<Value*(Value*)> value_map);
  TORCH_API void remapTypes(const std::function<TypePtr(TypePtr)>& type_map);

  // Lazily creates the shared Python-facing wrapper for this block and
  // returns it; later calls return the same wrapper.
  TORCH_API std::shared_ptr<Wrap<Block>> wrap() {
    if (!wrap_) {
      wrap_ = std::make_shared<Wrap<Block>>(this);
    }
    return wrap_;
  }

  // Clears the wrapper (if one was created) so Python-side handles are
  // invalidated when the block goes away.
  virtual ~Block() {
    if (wrap_) {
      wrap_->clear();
    }
  }

  // Empties the block: drops all outputs, destroys every node walking the
  // list in reverse, then drops all inputs. destroyCurrent() moves the
  // iterator off the node before destroying it, which is what makes the
  // it++ in the loop safe.
  void clear() {
    removeAllOutputs();
    for (auto it = nodes().rbegin(); it != nodes().rend(); it++) {
      it.destroyCurrent();
    }
    removeAllInputs();
  }

 private:
  void reIndexTopology();

  // get rid of all nodes
  // destroys in reverse order so that uses internal to this block
  // do not have to be removed before you can destroy the block
  void destroy();

  // NOTE: keep the declaration order of the members below; the out-of-line
  // constructor initializes them in this order.
  Graph* const graph_;
  // holds outputs in a way that can be reflected
  // as a Use object
  // also used as the beginning/end of the circular node list to avoid
  // having corner cases where the list is empty.
  Node* const output_;
  Node* const input_;
  Node* const
      owning_node_; // either the node that has this block or nullptr for root
  // a managing wrapper for Python to allow invalidation
  std::shared_ptr<Wrap<Block>> wrap_;
};
1180
1181struct Graph : std::enable_shared_from_this<Graph> {
1182 AT_DISALLOW_COPY_AND_ASSIGN(Graph);
1183 friend struct Node;
1184 friend struct Value;
1185 friend struct Block;
1186
1187 private:
1188 // only used to keep track of allocated nodes
1189 // actual representation of Graph is done with
1190 // inputs, outputs, nodes
1191
1192 std::unordered_set<const Node*> all_nodes;
1193 std::unordered_set<const Value*> all_values;
1194 std::unordered_set<const Block*> all_blocks;
1195 size_t next_unique_;
1196
1197 std::unordered_map<std::string, Value*> unique_names_;
1198 // name_base_suffix tracks largest suffix currently used by all names sharing
1199 // same name_base. Key of this map is name_base, value is largest suffix
1200 // numeric value.
1201 std::unordered_map<std::string, size_t> name_base_suffix_;
1202
1203 ScopePtr current_scope_;
1204
1205 Block* const block_;
1206 // when insertNode() is called, the node is inserted before this node
1207 // by default this is set to append to the top level block
1208 Node* insert_before_;
1209
1210 c10::optional<size_t> op_version_;
1211
1212 public:
1213 Graph(ScopePtr scope_root = c10::make_intrusive<Scope>())
1214 : next_unique_(0),
1215 current_scope_(std::move(scope_root)),
1216 block_(new Block(this, nullptr)),
1217 insert_before_(return_node()) {}
1218
1219 at::ArrayRef<Value*> inputs() {
1220 return block_->inputs();
1221 }
1222 at::ArrayRef<const Value*> inputs() const {
1223 const Block& block = *block_;
1224 return block.inputs();
1225 }
1226 at::ArrayRef<Value*> outputs() {
1227 return block_->outputs();
1228 }
1229 at::ArrayRef<const Value*> outputs() const {
1230 const Block& block = *block_;
1231 return block.outputs();
1232 }
1233 graph_node_list nodes() {
1234 return block_->nodes();
1235 }
1236 const_graph_node_list nodes() const {
1237 const Block& block = *block_;
1238 return block.nodes();
1239 }
1240 Node* param_node() {
1241 return block_->param_node();
1242 }
1243 const Node* param_node() const {
1244 return block_->param_node();
1245 }
1246 Node* return_node() {
1247 return block_->return_node();
1248 }
1249 const Node* return_node() const {
1250 return block_->return_node();
1251 }
1252 const std::unordered_map<std::string, Value*>& debugNames() const {
1253 return unique_names_;
1254 }
1255
1256 TORCH_API void push_scope(const std::string& scope_name);
1257 TORCH_API void pop_scope();
1258
1259 ScopePtr current_scope() {
1260 return current_scope_;
1261 }
1262
1263 void set_op_version(c10::optional<size_t> version) {
1264 op_version_ = version;
1265 }
1266
1267 c10::optional<size_t> get_op_version() {
1268 return op_version_;
1269 }
1270
1271 void set_current_scope(ScopePtr scope) {
1272 current_scope_ = std::move(scope);
1273 }
1274
1275 Value* addInput(const std::string& name = "") {
1276 return block_->addInput(name);
1277 }
1278 Value* insertInput(size_t i, const std::string& name = "") {
1279 return block_->insertInput(i, name);
1280 }
1281 void eraseInput(size_t i) {
1282 block_->eraseInput(i);
1283 }
1284 size_t registerOutput(Value* n) {
1285 return block_->registerOutput(n);
1286 }
1287 void eraseOutput(size_t i) {
1288 block_->eraseOutput(i);
1289 }
1290
1291 TORCH_API Node* create(NodeKind kind, size_t num_outputs = 1);
1292 TORCH_API Node* create(
1293 NodeKind kind,
1294 ArrayRef<Value*> inputs,
1295 size_t num_outputs = 1);
1296
1297 TORCH_API Node* createNone();
1298 TORCH_API Node* createAutogradZero();
1299 TORCH_API Node* createUninitialized(TypePtr typ);
1300 TORCH_API Node* createWithSubgraph(Symbol kind);
1301 TORCH_API Node* createDifferentiableSubgraph();
1302 TORCH_API Node* createTuple(
1303 at::ArrayRef<Value*> values,
1304 TupleTypePtr optional_named_tuple = nullptr);
1305 TORCH_API Node* createTupleUnpack(Value* v);
1306 TORCH_API Node* createTupleIndex(
1307 Value* tup,
1308 Value* idx,
1309 const TypePtr& output_type);
1310 TORCH_API Node* createTupleSlice(
1311 Value* tup,
1312 int64_t beg,
1313 int64_t step_size,
1314 int64_t num_values);
1315 TORCH_API Node* createEnumName(Value* e);
1316 TORCH_API Node* createEnumValue(Value* e);
1317 TORCH_API Node* createList(
1318 const TypePtr& contained_type,
1319 at::ArrayRef<Value*> values);
1320 TORCH_API Node* createListUnpack(Value* v, size_t size);
1321 TORCH_API Node* createDict(
1322 const TypePtr& key_type,
1323 const TypePtr& value_type,
1324 at::ArrayRef<Value*> keys,
1325 at::ArrayRef<Value*> values);
1326 TORCH_API Node* createNumToTensor(Value* value);
1327 TORCH_API Node* createObject(const ClassTypePtr& type);
1328 TORCH_API Node* createSetAttr(
1329 Value* obj,
1330 const std::string& field,
1331 Value* newValue);
1332 TORCH_API Node* createGetAttr(Value* obj, const std::string& field);
1333 Value* insertGetAttr(Value* obj, const std::string& field) {
1334 return insertNode(createGetAttr(obj, field))->output();
1335 }
1336 TORCH_API Node* createStore(const std::string& name, Value* v);
1337 TORCH_API Node* createLoad(const std::string& name, const TypePtr& type);
1338 TORCH_API Node* createIsInstance(Value* v, at::ArrayRef<TypePtr> types);
1339
1340 TORCH_API Value* insertUncheckedCast(Value* v, TypePtr type);
1341
1342 // Insert a ToList operator with argument \p v and output type \p type.
1343 // \returns the output of the operation.
1344 TORCH_API Value* insertToList(Value* v, TypePtr type);
1345
1346 TORCH_API Value* insertFunctionCall(
1347 Function* callee,
1348 const MatchedSchema& matched);
1349 TORCH_API Value* insertMethodCall(
1350 std::string method_name,
1351 const MatchedSchema& matched);
1352
1353 // Note: defined in python_ir.cpp and can be used only in python extension
1354 Node* createPythonOp(
1355 THPObjectPtr&& pyobj,
1356 const std::string& cconv,
1357 pyobj_list&& scalar_args);
1358 // clone n, making a new node in _this_ graph.
1359 // use value_map to translate inputs of n to inputs of the cloned node
1360 // if copy_blocks is false, it will not recursively clone the nested blocks
1361 // this node contains.
1362 TORCH_API Node* createClone(
1363 Node* n,
1364 const std::function<Value*(Value*)>& value_map,
1365 bool copy_blocks = true);
1366
1367 // Insert constant IValue into the graph.
1368 TORCH_API Value* insertConstant(
1369 const IValue& val,
1370 c10::optional<SourceRange> loc = c10::nullopt,
1371 c10::optional<ScopePtr> scope = c10::nullopt);
1372
1373 // Schema-driven insert:
1374 // This inserts a node into the graph with inputs determined from args and
1375 // kwargs using Python argument matching rules, and checks that the op matches
1376 // a known schema.
1377 //
1378 // If this node successfully completes, it guarentees the node
1379 // is a correctly-formed invocation of opname
1380 TORCH_API Value* insert(
1381 Symbol opname,
1382 at::ArrayRef<NamedValue> args,
1383 at::ArrayRef<NamedValue> kwargs = {},
1384 const c10::optional<SourceRange>& range = {});
1385
1386 Node* appendNode(Node* n) {
1387 return block_->appendNode(n);
1388 }
1389
1390 Node* prependNode(Node* n) {
1391 return block_->prependNode(n);
1392 }
1393
1394 // insert before insert_before_ node
1395 // initialized to insert at the end of the top level block
1396 // can be changed with setInsertPoint()
1397 Node* insertNode(Node* n) {
1398 AT_ASSERT(
1399 insert_before_->inBlockList() &&
1400 "insert point node is no longer in a block list");
1401 return n->insertBefore(insert_before_);
1402 }
1403 // set where nodes are inserted to append to the end of this block
1404 void setInsertPoint(Block* b) {
1405 AT_ASSERT(b->owningGraph() == this);
1406 insert_before_ = b->return_node();
1407 }
1408 // set where nodes are inserted to insert _before_ this node
1409 // for implementation simplicity we only support inserting before a node for
1410 // now
1411 void setInsertPoint(Node* n) {
1412 AT_ASSERT(n->owningGraph() == this && n->inBlockList());
1413 insert_before_ = n;
1414 }
1415 Node* insertPoint() {
1416 return insert_before_;
1417 }
1418
1419 // the top level block
1420 Block* block() {
1421 return block_;
1422 }
1423 const Block* block() const {
1424 return block_;
1425 }
1426
1427 // Checks well-formedness and invariants of graph
1428 TORCH_API void lint() const;
1429 // for use in debugger
1430 TORCH_API void dump() const;
1431
1432 TORCH_API ~Graph();
1433
1434 TORCH_API std::string toString(bool print_source_locations = true) const;
1435
1436 TORCH_API std::ostream& print(
1437 std::ostream& out,
1438 bool print_source_locations = true) const;
1439
1440 friend TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);
1441
1442 TORCH_API std::shared_ptr<Graph> copy();
1443 TORCH_API std::unique_ptr<Graph> copyUnique();
1444 TORCH_API void remapTypes(const std::function<TypePtr(TypePtr)>& type_map);
1445
1446 private:
1447 friend TORCH_API void Lint(const AliasDb* db);
1448 TORCH_API void freeNode(Node* n);
1449 TORCH_API void freeValue(Value* v);
1450 TORCH_API void freeBlock(Block* b);
1451 void cloneFrom(Graph& src);
1452};
1453
1454/** \brief An utility class for setting temporary insertion points.
1455 *
1456 * When an object of this class is created, it stores the current insertion
1457 * point, sets the new one, and restores the original insertion point when the
1458 * object is destroyed.
1459 */
1460struct WithInsertPoint {
1461 WithInsertPoint(Node* n) : prev_(n->owningGraph()->insertPoint()) {
1462 n->owningGraph()->setInsertPoint(n);
1463 }
1464 WithInsertPoint(Block* b) : WithInsertPoint(b->return_node()) {}
1465
1466 ~WithInsertPoint() {
1467 prev_->owningGraph()->setInsertPoint(prev_);
1468 }
1469
1470 private:
1471 Node* prev_;
1472};
1473
1474/** \brief An utility class for setting temporary scopes.
1475 *
1476 * When an object of this class is created, it stores the current scope, sets
1477 * the new one, and restores the original scope when the object is destroyed.
1478 */
1479struct WithCurrentScope {
1480 WithCurrentScope(Graph& g, ScopePtr scope)
1481 : graph_(&g), prev_scope_(g.current_scope()) {
1482 g.set_current_scope(std::move(scope));
1483 }
1484 ~WithCurrentScope() {
1485 graph_->set_current_scope(prev_scope_);
1486 }
1487
1488 private:
1489 Graph* graph_;
1490 ScopePtr prev_scope_;
1491};
1492
1493// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
1494inline Value::Value(Node* node_, size_t offset_)
1495 : node_(node_),
1496 offset_(offset_),
1497 unique_(node_->graph_->next_unique_++),
1498 type_(TensorType::get()) {
1499 node_->graph_->all_values.emplace(this);
1500}
1501
1502inline Value* Value::setType(TypePtr type) {
1503 AT_ASSERT(type);
1504 if (auto dyn = type->castRaw<c10::DynamicType>()) {
1505 type = dyn->fallback();
1506 }
1507 type_ = std::move(type);
1508 for (Use& use : uses_) {
1509 use.user->op_ = nullptr;
1510 }
1511 return this;
1512}
1513
1514inline Graph* Value::owningGraph() {
1515 return node()->owningGraph();
1516}
1517
1518inline const Graph* Value::owningGraph() const {
1519 return node()->owningGraph();
1520}
1521
1522/************* All nodes not required to be defined before Graph **************/
1523struct ProfileOp : public Node {
1524 static const Symbol Kind;
1525 ProfileOp(Graph* graph, std::function<void(std::vector<IValue>&)> callback)
1526 : Node(graph, ::c10::prim::profile), callback_(std::move(callback)) {}
1527
1528 void cloneFrom(Node* other_) override;
1529 Node* allocNewInstance(Graph* g) override;
1530
1531 const std::function<void(std::vector<IValue>&)>& getCallback() const {
1532 return callback_;
1533 }
1534
1535 void setCallback(std::function<void(std::vector<IValue>&)> callback) {
1536 callback_ = std::move(callback);
1537 }
1538
1539 bool hasSeenTensor() const {
1540 return has_seen_tensor_;
1541 }
1542
1543 void setHasSeenTensor(bool has_seen_tensor) {
1544 has_seen_tensor_ = has_seen_tensor;
1545 }
1546
1547 private:
1548 std::function<void(std::vector<IValue>&)> callback_;
1549 bool has_seen_tensor_ = false;
1550};
1551
1552struct TORCH_API ProfileIValueOp : public Node {
1553 static const Symbol Kind;
1554 ProfileIValueOp(
1555 Graph* graph,
1556 std::function<void(std::vector<IValue>&)> callback)
1557 : Node(graph, ::c10::prim::profile_ivalue),
1558 callback_(std::move(callback)) {}
1559
1560 void cloneFrom(Node* other_) override;
1561 Node* allocNewInstance(Graph* g) override;
1562
1563 const std::function<void(std::vector<IValue>&)>& getCallback() const {
1564 return callback_;
1565 }
1566
1567 void setCallback(std::function<void(std::vector<IValue>&)> callback) {
1568 callback_ = callback;
1569 }
1570
1571 private:
1572 std::function<void(std::vector<IValue>&)> callback_;
1573};
1574
// Executes a Python function. Used for ops we can't optimize but that we want
// to optimize around.
//
// Note: actual implementation (ConcretePythonOp) is defined in python_ir.cpp
// which is not included in libtorch.so. We still include some bits and pieces
// of PythonOp here to enable writing simple passes generically. In general,
// python-aware bits need to be moved to the descendant classes.
struct TORCH_API PythonOp : public Node {
  // Reuse Node's constructors.
  using Node::Node;

  // NOTE(review): presumably a display name for the wrapped Python callable
  // — confirm in python_ir.cpp.
  virtual std::string name() const = 0;
  // Writes this op's scalar arguments to `out`; the format is defined by the
  // implementation in python_ir.cpp.
  virtual void writeScalars(std::ostream& out) const = 0;
  // Re-declared pure so that concrete subclasses must provide their own
  // cloning hooks instead of inheriting Node's.
  void cloneFrom(Node* other_) override = 0;
  Node* allocNewInstance(Graph* g) override = 0;
  // recover the autograd.Function instance, if this PythonOp's function
  // was originally SomeFunction.apply
  // used in ONNX for discovering symbolics
  virtual c10::optional<THPObjectPtr> autogradFunction() const = 0;

  // Python-specific lint checks. NOTE(review): presumably complements
  // Node::lint — confirm in python_ir.cpp.
  virtual void lint_python() const = 0;
};
1596
// Runs lint checks on `graph`. NOTE(review): presumably equivalent to
// graph->lint() — confirm in ir.cpp.
TORCH_API void LintGraph(const std::shared_ptr<Graph>& graph);

// NOTE(review): presumably inserts a tuple-unpack of `v` and returns the
// unpacked values — confirm in ir.cpp.
TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);

/** Insert graph \p callee into graph \p g using \p inputs as input values.
 * The insertion happens at the current insertion point.
 * Optionally, one can also pass \p value_map to get a map between \p callee
 * values and their cloned copies in \p g.
 */
TORCH_API std::vector<Value*> insertGraph(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs);
TORCH_API std::vector<Value*> insertGraph(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs,
    std::unordered_map<Value*, Value*>& value_map);

/** Insert function \p callee after node \p to_replace, remove the node and
 * replace all its uses with corresponding outputs of the inserted function.
 * This asserts that the number of outputs of the original node and the
 * graph are the same.
 * NOTE(review): \p use_graph selects which of the callee's graphs is inlined
 * — confirm the exact choice in ir.cpp.
 */
TORCH_API std::vector<Value*> inlineCallTo(
    Node* to_replace,
    GraphFunction* callee,
    bool use_graph = true);

// Same as above, but inlines the explicitly supplied \p callee_graph.
// NOTE(review): confirm in ir.cpp how \p callee is still used here.
TORCH_API std::vector<Value*> inlineCallTo(
    Node* to_replace,
    GraphFunction* callee,
    Graph* callee_graph);

/** If \p outputs holds exactly one value and that value's type is a Tuple,
 * insert a tuple unpack node and return the resulting values.
 * NOTE(review): presumably returns \p outputs unchanged otherwise — confirm.
 */
TORCH_API std::vector<Value*> unpackOutputs(const std::vector<Value*>& outputs);

// Collect the nodes of kind \p kind found in a graph, a block, or a list of
// blocks. NOTE(review): \p recurse presumably also searches nested blocks —
// confirm in ir.cpp.
TORCH_API std::vector<Node*> findAllNodes(Graph& g, Symbol kind, bool recurse);
TORCH_API std::vector<Node*> findAllNodes(Block& b, Symbol kind, bool recurse);
TORCH_API std::vector<Node*> findAllNodes(
    at::ArrayRef<Block*> a,
    Symbol kind,
    bool recurse);
1642
// A set of operators, built from operator signature literals and stored
// grouped by Symbol. NOTE(review): the constructor and insert() presumably
// resolve each literal to its registered Operator — confirm in ir.cpp.
struct TORCH_API OperatorSet {
  OperatorSet(std::initializer_list<const char*> sig_literals);
  // Flattened list of every operator in the set.
  std::vector<std::shared_ptr<Operator>> getOps() const;
  void insert(std::initializer_list<const char*> sig_literals);

 private:
  // Node reads `ops` directly.
  friend struct Node;
  std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>> ops;
};
1652
1653template <typename T>
1654// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
1655struct OperatorMap {
1656 // Type aliasing
1657 using OpMapType = typename std::pair<std::shared_ptr<Operator>, T>;
1658 using ValueType = std::vector<OpMapType>;
1659 using MapType = std::unordered_map<Symbol, ValueType>;
1660
1661 OperatorMap() = default;
1662 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
1663 explicit OperatorMap(
1664 std::initializer_list<std::pair<std::shared_ptr<Operator>, T>> init) {
1665 insert(init);
1666 }
1667 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
1668 explicit OperatorMap(std::initializer_list<std::pair<const char*, T>> init) {
1669 insert(init);
1670 }
1671
1672 void insert(const std::shared_ptr<Operator>& op, T val) {
1673 // Remove if exists before insert
1674 erase(op);
1675 map[Symbol::fromQualString(op->schema().name())].emplace_back(
1676 std::make_pair(op, val));
1677 }
1678
1679 void insert(const OperatorSet& op_set, T val) {
1680 for (auto& op : op_set.getOps()) {
1681 insert(op, val);
1682 }
1683 }
1684
1685 void insert(
1686 std::initializer_list<std::pair<std::shared_ptr<Operator>, T>> v) {
1687 for (auto& el : v) {
1688 insert(el.first, el.second);
1689 }
1690 }
1691
1692 void insert(std::initializer_list<std::pair<const char*, T>> v) {
1693 for (auto& el : v) {
1694 insert(getOperatorForLiteral(el.first), el.second);
1695 }
1696 }
1697
1698 void erase(const std::shared_ptr<Operator>& op) {
1699 auto it = map.find(Symbol::fromQualString(op->schema().name()));
1700 if (it == map.end()) {
1701 return;
1702 }
1703 for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1704 if (vit->first->schema() == op->schema()) {
1705 it->second.erase(vit);
1706 break;
1707 }
1708 }
1709 if (it->second.size() == 0) {
1710 map.erase(Symbol::fromQualString(op->schema().name()));
1711 }
1712 }
1713
1714 bool contains(const Operator& op) const {
1715 const auto it = map.find(Symbol::fromQualString(op.schema().name()));
1716 if (it == map.end()) {
1717 return false;
1718 }
1719 for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1720 if (vit->first->schema() == op.schema()) {
1721 return true;
1722 }
1723 }
1724 return false;
1725 }
1726
1727 bool contains(const Node* n) const {
1728 return n->maybeOperator() && contains(n->getOperator());
1729 }
1730
1731 c10::optional<T> find(const Operator& op) {
1732 const auto it = map.find(Symbol::fromQualString(op.schema().name()));
1733 if (it == map.end()) {
1734 return c10::nullopt;
1735 }
1736 for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1737 if (vit->first->schema() == op.schema()) {
1738 return vit->second;
1739 }
1740 }
1741 return c10::nullopt;
1742 }
1743
1744 // TODO: return iterator
1745 std::vector<OpMapType> getAllKeysAndValues() const {
1746 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1747 std::vector<OpMapType> keys_values;
1748 for (auto& symbol_mapping : map) {
1749 auto& vec = symbol_mapping.second;
1750 for (auto& pair : vec) {
1751 keys_values.push_back(pair);
1752 }
1753 }
1754 return keys_values;
1755 }
1756
1757 private:
1758 friend struct Node;
1759 MapType map;
1760};
1761
1762template <typename T>
1763// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
1764struct FunctionSchemaMap {
1765 // Type aliasing
1766 using FuncSchemaMapType = typename std::pair<FunctionSchema, T>;
1767 using ValueType = std::vector<FuncSchemaMapType>;
1768 using MapType = std::unordered_map<Symbol, ValueType>;
1769
1770 FunctionSchemaMap() = default;
1771 void insert(const FunctionSchema& schema, T val) {
1772 // Remove if exists before insert
1773 erase(schema);
1774 map[Symbol::fromQualString(schema.name())].emplace_back(
1775 std::make_pair(schema, val));
1776 }
1777
1778 void erase(const FunctionSchema& schema) {
1779 auto it = map.find(Symbol::fromQualString(schema.name()));
1780 if (it == map.end()) {
1781 return;
1782 }
1783 for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1784 if (vit->first == schema) {
1785 it->second.erase(vit);
1786 break;
1787 }
1788 }
1789 if (it->second.size() == 0) {
1790 map.erase(Symbol::fromQualString(schema.name()));
1791 }
1792 }
1793
1794 bool contains(const FunctionSchema& schema) const {
1795 const auto it = map.find(Symbol::fromQualString(schema.name()));
1796 if (it == map.end()) {
1797 return false;
1798 }
1799 for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1800 if (vit->first->schema() == schema) {
1801 return true;
1802 }
1803 }
1804 return false;
1805 }
1806
1807 c10::optional<T> find(const FunctionSchema& schema) const {
1808 const auto it = map.find(Symbol::fromQualString(schema.name()));
1809 if (it == map.end()) {
1810 return c10::nullopt;
1811 }
1812 for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) {
1813 if (vit->first == schema) {
1814 return vit->second;
1815 }
1816 }
1817 return c10::nullopt;
1818 }
1819
1820 // TODO: return iterator
1821 std::vector<FuncSchemaMapType> getAllKeysAndValues() const {
1822 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
1823 std::vector<FuncSchemaMapType> keys_values;
1824 for (auto& symbol_mapping : map) {
1825 auto& vec = symbol_mapping.second;
1826 for (auto& pair : vec) {
1827 keys_values.push_back(pair);
1828 }
1829 }
1830 return keys_values;
1831 }
1832
1833 private:
1834 friend struct Node;
1835 MapType map;
1836};
1837
1838} // namespace jit
1839} // namespace torch
1840