1 | #pragma once |
2 | |
3 | #include <torch/csrc/jit/ir/attributes.h> |
4 | #include <torch/csrc/jit/ir/graph_node_list.h> |
5 | #include <torch/csrc/jit/ir/named_value.h> |
6 | #include <torch/csrc/jit/ir/scope.h> |
7 | #include <torch/csrc/jit/runtime/operator.h> |
8 | |
9 | #include <torch/csrc/Export.h> |
10 | #include <torch/csrc/utils/python_stub.h> |
11 | #include <torch/csrc/utils/schema_info.h> |
12 | |
13 | #include <ATen/Utils.h> |
14 | #include <ATen/core/Tensor.h> |
15 | #include <ATen/core/dynamic_type.h> |
16 | #include <ATen/core/enum_type.h> |
17 | #include <ATen/core/functional.h> |
18 | #include <ATen/core/interned_strings.h> |
19 | #include <ATen/core/ivalue.h> |
20 | #include <ATen/core/jit_type.h> |
21 | #include <c10/util/ArrayRef.h> |
22 | #include <c10/util/Exception.h> |
23 | #include <c10/util/Optional.h> |
24 | |
25 | #include <functional> |
26 | #include <iostream> |
27 | #include <unordered_set> |
28 | #include <vector> |
29 | |
30 | // Forward declare, the real meat is in python_ir.cpp |
31 | template <class T> |
32 | class THPPointer; |
33 | using THPObjectPtr = THPPointer<PyObject>; |
34 | using pyobj_list = std::vector<THPObjectPtr>; |
35 | |
36 | namespace torch { |
37 | namespace jit { |
38 | namespace utils { |
39 | TORCH_API std::string getNodesModuleHierarchy(const Node& n); |
40 | } // namespace utils |
41 | class AliasDb; |
42 | |
43 | using ::c10::Argument; |
44 | using ::c10::FunctionSchema; |
45 | using ::c10::Symbol; |
46 | |
47 | using ::c10::ivalue::Shared; |
48 | |
49 | using ::c10::IValue; |
50 | using ::c10::ivalue::Future; |
51 | |
52 | using ::c10::ivalue::ConstantString; |
53 | |
54 | #define C10_USING(T) using ::c10::T; |
55 | C10_FORALL_TYPES(C10_USING) |
56 | #undef C10_USING |
57 | |
58 | #define C10_USING(T) using ::c10::T##Ptr; |
59 | C10_FORALL_TYPES(C10_USING) |
60 | #undef C10_USING |
61 | |
62 | using ::c10::Type; |
63 | using ::c10::TypeEnv; |
64 | using ::c10::TypePtr; |
65 | |
66 | using ::c10::getTypePtr; |
67 | using ::c10::MatchTypeReturn; |
68 | using ::c10::TypeKind; |
69 | |
70 | using ::c10::fmap; |
71 | |
72 | namespace prim { |
73 | using namespace ::c10::prim; |
74 | } |
75 | namespace attr { |
76 | using namespace ::c10::attr; |
77 | } |
78 | namespace aten { |
79 | using namespace ::c10::aten; |
80 | } |
81 | namespace cuda { |
82 | #if !defined(USE_ROCM) |
83 | using namespace ::c10::cuda; |
84 | #endif |
85 | } // namespace cuda |
86 | |
87 | struct Function; |
88 | struct GraphFunction; |
89 | struct MatchedSchema; |
90 | |
91 | // A Graph represents one "function" of computation. |
92 | // It uses a simple ownership model where the graph owns all the nodes inside |
93 | // it. All references inside the graph are raw pointers. Destroying the Graph |
94 | // will invalidate any pointers to nodes in the graph. |
95 | struct Graph; |
96 | |
97 | // Node is the base class of the IR graph. It represents one computation |
98 | // and dependencies on a list of Values. The "prim-ops", so to speak. |
99 | struct Node; |
100 | |
101 | // A Value represents an input or output to node that is either a |
102 | // Tensor or an opaque Handle object, as determined by type(). |
103 | struct Value; |
104 | |
105 | TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g); |
106 | TORCH_API std::ostream& operator<<(std::ostream& out, const Node& n); |
107 | |
108 | // A list of nodes, with inputs and outputs |
109 | struct Block; |
110 | |
// Each use is represented by this type; see 'Node::uses()'.
// 'user' is the node consuming the value, and 'offset' is the index into
// 'user's inputs at which the used value appears.
114 | struct Use { |
115 | Use(Node* user, size_t offset) : user(user), offset(offset) {} |
116 | Node* user; |
117 | size_t offset; |
118 | |
119 | bool operator==(const Use& b) { |
120 | return user == b.user && offset == b.offset; |
121 | } |
122 | }; |
123 | |
124 | // Note [User node does not uniquely identify use] |
125 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
126 | // A while back, we wrote some code manipulating uses that looked like this: |
127 | // |
128 | // for (auto& use : used_val->uses_) { |
129 | // if (use.user == this_node) { |
130 | // use.offset += 1; |
131 | // break; |
132 | // } |
133 | // } |
134 | // |
135 | // This code is trying to find a particular use (our node's use) to update it. |
136 | // However, it's wrong: there may be *multiple* uses of a value %x in a node, |
137 | // as might be the case in this IR: |
138 | // |
139 | // %y = Add %x %x |
140 | // |
141 | // In this case, there are two uses of %x whose user is the node 'Add %x %x'. |
142 | // So, "use induced by this node" is not a well-formed concept. |
143 | // |
144 | // If you are looking for "use induced by an input", it's best to use |
145 | // findUseForInput() to get it. |
146 | |
147 | // the list types are intentionally simple, but we type-def |
148 | // them here so if we need to change them, refactoring will be easier |
149 | using node_list = std::vector<Node*>; |
150 | using value_list = std::vector<Value*>; |
151 | using use_list = std::vector<Use>; |
152 | template <typename T> |
153 | using ArrayRef = at::ArrayRef<T>; |
154 | using NodeKind = Symbol; |
155 | using topo_position_t = int64_t; |
156 | using ValueSet = std::unordered_set<const Value*>; |
157 | |
158 | struct OperatorSet; |
159 | template <typename T> |
160 | struct OperatorMap; |
161 | |
162 | // This is a wrapper to allow invalidating the Python object |
163 | // safely when the C++ object for a Node/Value/Block is deleted |
164 | // like much of graph, it isn't safe for different threads to |
165 | // access the same graph |
template <typename T>
struct Wrap {
  explicit Wrap(T* ptr) : elem(ptr), clear_cb(nullptr) {}

  // Invalidate the wrapped pointer. If a callback was installed (by the
  // Python bindings), notify it with the element first so the proxy
  // object can detach itself.
  void clear() {
    if (clear_cb != nullptr) {
      clear_cb(elem);
    }
    elem = nullptr;
  }

  // The wrapped C++ object; nullptr once clear() has run.
  T* elem;
  // Optional hook invoked with the element just before invalidation.
  void (*clear_cb)(void*);
};
178 | |
struct Value {
  AT_DISALLOW_COPY_AND_ASSIGN(Value);
  // Values are created by their producing Node; `offset_` is the index of
  // this value within that node's outputs.
  Value(Node* node_, size_t offset_);

 private:
  friend struct Node;
  friend struct Graph;
  // The node that produces this value.
  Node* node_;
  // Index of this value in node_->outputs().
  size_t offset_;
  size_t unique_ = 0; // unique id
  // Every (consumer node, input index) pair that reads this value.
  use_list uses_;
  // Optional human-readable name; empty means "unnamed", in which case
  // debugName() falls back to the unique id.
  std::string unique_name_;
  TypePtr type_;
  // a managing wrapper for Python to allow invalidation
  std::shared_ptr<Wrap<Value>> wrap_;

 public:
  // Sets the static type of this value; returns `this` for chaining.
  Value* setType(TypePtr type);
  // Infers and sets the type from a concrete tensor / object instance.
  TORCH_API void inferTypeFrom(const at::Tensor& output);
  TORCH_API void inferTypeFrom(
      const c10::intrusive_ptr<c10::ivalue::Object>& output);
  // The static type of this value; asserts that a type has been assigned.
  const TypePtr& type() const {
    AT_ASSERT(type_ != nullptr);
    return type_;
  }
  bool requires_grad() const {
    return type()->requires_grad();
  }
  // True iff this value is a TensorType reporting complete information
  // (TensorType::isComplete()); false for any non-tensor type.
  bool isCompleteTensor() const {
    if (auto pt = type()->cast<TensorType>()) {
      return pt->isComplete();
    }
    return false;
  }
  TORCH_API bool mustBeNone() const;
  TORCH_API bool mustNotBeNone() const;
  size_t unique() const {
    return unique_;
  }
  bool hasDebugName() const {
    return !unique_name_.empty();
  }
  static bool isValidName(const std::string& name);
  TORCH_API Value* setDebugName(const std::string& name);
  // The debug name if one was set, otherwise the unique id as a string.
  std::string debugName() const {
    if (hasDebugName()) {
      return unique_name_;
    }
    return c10::to_string(unique());
  }
  TORCH_API std::string debugNameBase() const;
  Node* node() {
    return node_;
  }
  // Index of this value among its producing node's outputs.
  size_t offset() const {
    return offset_;
  }
  void setOffset(size_t offset) {
    offset_ = offset;
  }
  const Node* node() const {
    return node_;
  }

  /**
   * @warning NEVER pass raw pointer of smart pointer managed Graph to Python.
   * Check #87343 for details.
   */
  Graph* owningGraph();
  const Graph* owningGraph() const;
  // TODO: make this more const correct
  const use_list& uses() const {
    return uses_;
  }

  bool hasUses() const {
    return !uses().empty();
  }

  // Rewires only the first recorded use of this value to 'newValue'.
  TORCH_API void replaceFirstUseWith(Value* newValue);

  // Replaces all uses of this value with 'newValue'.
  //
  // Given: %3 = f(%1, %2)
  //        %4 = g(%3)
  //        %5 = h(%3, %3)
  // Execute: %3.replaceAllUsesWith(%6)
  // Result: %3 = f(%1, %2)
  //         %4 = g(%6)
  //         %5 = h(%6, %6)
  TORCH_API void replaceAllUsesWith(Value* newValue);

  // Replaces all uses of this value with 'newValue' after 'node'.
  // Given: %3 = f(%1, %2)
  //        %4 = g(%3)
  //        %5 = inplace_(%3)
  //        %6 = h(%3, %3)
  // Execute: %3.replaceAllUsesAfterNodeWith(%5.node(), %5)
  // Result: %3 = f(%1, %2)
  //         %4 = g(%3)
  //         %5 = inplace_(%3)
  //         %6 = h(%5, %5)
  // XXX: does not check scoping legality, consider using
  // replaceAllUsesDominatedByNodeWith
  TORCH_API void replaceAllUsesAfterNodeWith(const Node* node, Value* newValue);

  // Replaces all uses of this value with 'newValue' that are dominated by
  // 'node'. Given:
  // x = op(...).
  // if cond:
  //    z = foo(..)
  //    bar(x)
  // else:
  //    print(x)
  // x.replaceAllUsesDominatedByNodeWith(foo, z) would replace bar(x)
  // but not print(x) because print is not dominated by foo.
  // replaceAllUsesAfterNode does not check domination, so in this example
  // it would produce invalid IR.
  TORCH_API void replaceAllUsesDominatedByNodeWith(
      const Node* node,
      Value* newValue);

  // Copies value metadata from 'from' onto this value (presumably the
  // debug name and type; implementation lives in ir.cpp — confirm there).
  TORCH_API Value* copyMetadata(Value* from);

  // Lazily creates (and caches) the wrapper handed to the Python bindings
  // so they can observe when this Value is destroyed.
  TORCH_API std::shared_ptr<Wrap<Value>> wrap() {
    if (!wrap_) {
      wrap_ = std::make_shared<Wrap<Value>>(this);
    }
    return wrap_;
  }

  virtual ~Value() {
    // Invalidate any outstanding Python wrapper so it cannot dangle.
    if (wrap_) {
      wrap_->clear();
    }
  }
};
316 | |
317 | struct TORCH_API Node { |
318 | AT_DISALLOW_COPY_AND_ASSIGN(Node); |
319 | friend struct Graph; |
320 | friend struct Block; |
321 | friend struct Value; |
322 | friend graph_node_list; |
323 | friend const_graph_node_list; |
324 | friend graph_node_list_iterator; |
325 | friend const_graph_node_list_iterator; |
326 | |
327 | private: |
328 | const NodeKind kind_; |
329 | std::vector<Value*> inputs_; |
330 | std::vector<Value*> outputs_; |
331 | // subblocks |
332 | std::vector<Block*> blocks_; |
333 | Graph* graph_; |
334 | Block* owning_block_; |
335 | c10::optional<SourceRange> source_range_; |
336 | ScopePtr scope_; |
337 | c10::optional<InlinedCallStackPtr> callstack_; |
338 | // Assumes FunctionSchemas are persistent, so we don't manage their lifetime. |
  // This field is effectively a cache that's populated on attribute lookups
340 | // invalidated every time we perform an operation that could potentially |
341 | // change the schema. note: mutable because schema_ is effectively a cache |
342 | mutable const Operator* op_; |
343 | topo_position_t topo_position_ = 0; |
344 | // a managing wrapper for Python to allow invalidation |
345 | std::shared_ptr<Wrap<Node>> wrap_; |
346 | // Stores the full schema name, if the operator is historic |
347 | // When the operator is deprecated or the name of the operator |
348 | // is changed, we need to rely on this name |
349 | // to retrieve old schemas to successfully apply upgraders |
350 | // for this operator. |
351 | c10::optional<std::string> historic_schema_name_ = c10::nullopt; |
352 | |
353 | protected: |
354 | Node(Graph* graph_, NodeKind kind_); // defined after graph |
355 | public: |
356 | // Each Node but Return/Param Nodes are associated with exactly one |
357 | // place in the Node list of the Graph. The Graph itself is a circular |
358 | // doubly-linked list. The Return Node is used as the sentinel for the |
359 | // "beginning"/"end" of the list. This means that you can tell when |
  // you've traversed the entire list without worrying about null
361 | // pointers. `next_in_graph[0]` is the pointer to the next Node, while |
362 | // `next_in_graph[1]` is the pointer to the previous Node. The |
363 | // linked list is implemented as an array to allow the same iterator |
364 | // class for forward and reversed Node lists. Taken together, this |
365 | // list also represents a topological sort of the Nodes in the Graph. |
366 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-non-private-member-variables-in-classes,modernize-avoid-c-arrays) |
367 | Node* next_in_graph[2] = {nullptr, nullptr}; |
368 | |
369 | std::shared_ptr<Wrap<Node>> wrap() { |
370 | if (!wrap_) { |
371 | wrap_ = std::make_shared<Wrap<Node>>(this); |
372 | } |
373 | return wrap_; |
374 | } |
375 | |
376 | const c10::optional<std::string> getHistoricSchemaName() { |
377 | return historic_schema_name_; |
378 | } |
379 | |
380 | void setHistoricSchemaName(const std::string& name) { |
381 | historic_schema_name_ = name; |
382 | } |
383 | |
384 | Node*& next() { |
385 | return next_in_graph[kNextDirection]; |
386 | } |
387 | Node*& prev() { |
388 | return next_in_graph[kPrevDirection]; |
389 | } |
390 | Node* const& next() const { |
391 | return next_in_graph[kNextDirection]; |
392 | } |
393 | Node* const& prev() const { |
394 | return next_in_graph[kPrevDirection]; |
395 | } |
396 | |
397 | NodeKind kind() const { |
398 | return kind_; |
399 | } |
400 | Node* setSourceRange(SourceRange r) { |
401 | source_range_ = std::move(r); |
402 | return this; |
403 | } |
404 | SourceRange sourceRange() const; |
405 | |
406 | /** |
407 | * @warning NEVER pass raw pointer of smart pointer managed Graph to Python. |
408 | * Check #87343 for details. |
409 | */ |
410 | Graph* owningGraph() { |
411 | return graph_; |
412 | } |
413 | const Graph* owningGraph() const { |
414 | return graph_; |
415 | } |
416 | Block* owningBlock() { |
417 | return owning_block_; |
418 | } |
419 | const Block* owningBlock() const { |
420 | return owning_block_; |
421 | } |
422 | ScopePtr scope() { |
423 | return scope_; |
424 | } |
425 | void setScope(ScopePtr scope) { |
426 | scope_ = std::move(scope); |
427 | } |
428 | std::string scopeName() const { |
429 | if (!scope_) { |
430 | return "" ; |
431 | } |
432 | return scope_->namesFromRoot(); |
433 | } |
434 | |
435 | // Copies the source range, scope and callstack from another node. |
436 | Node* copyMetadata(Node* from) { |
437 | this->setSourceRange(from->sourceRange()); |
438 | this->setScope(from->scope()); |
439 | if (auto cs = from->callstack()) { |
440 | this->setCallStack(*cs); |
441 | } |
442 | return this; |
443 | } |
444 | |
445 | c10::optional<InlinedCallStackPtr> callstack() const { |
446 | return callstack_; |
447 | } |
448 | void setCallStack(InlinedCallStackPtr cs) { |
449 | callstack_ = cs; |
450 | } |
451 | |
452 | // NB: This returns an ArrayRef; that means that it will |
453 | // get invalidated if you resize inputs (e.g., using addInput) |
454 | // We can't return a std::vector<Node*>& because there's no |
455 | // way to soundly cast to std::vector<const Node*> (an insane |
456 | // implementation of std::vector could make this representationally |
457 | // different.) |
458 | at::ArrayRef<Value*> inputs() { |
459 | return inputs_; |
460 | } |
461 | at::ArrayRef<const Value*> inputs() const { |
462 | // Vectors are not convertible in const-ness of elements, but |
463 | // raw pointers are. |
464 | return {inputs_.data(), inputs_.size()}; |
465 | } |
466 | // NB: This returns an ArrayRef; that means that it will |
467 | // get invalidated if you resize inputs (e.g., using addInput) |
468 | // We can't return a std::vector<Node*>& because there's no |
469 | // way to soundly cast to std::vector<const Node*> (an insane |
470 | // implementation of std::vector could make this representationally |
471 | // different.) |
472 | at::ArrayRef<Value*> outputs() { |
473 | return outputs_; |
474 | } |
475 | at::ArrayRef<const Value*> outputs() const { |
476 | // Vectors are not convertible in const-ness of elements, but |
477 | // raw pointers are. |
478 | return {outputs_.data(), outputs_.size()}; |
479 | } |
480 | Value* output(size_t i) const { |
481 | return outputs_.at(i); |
482 | } |
483 | bool hasUses() const { |
484 | for (auto o : outputs()) { |
485 | if (!o->uses().empty()) { |
486 | return true; |
487 | } |
488 | } |
489 | return false; |
490 | } |
491 | |
492 | void replaceAllUsesWith(Node* n); |
493 | |
494 | // replaces `this` with a new node with the same inputs and outputs |
495 | // but a new node symbol. does not destroy `this` |
496 | Node* replaceWithNewSymbol(Symbol new_symbol); |
497 | |
498 | // Checks if this node is dominated by `dominator` which means that |
499 | // `dominator` will always be executed before `this` and `dominator` |
  // is in scope of `this`.
501 | bool isDominatedBy(const Node* dominator) const; |
502 | |
503 | // lots of things like chunk have a single input or single output, so we have |
504 | // a helper to make accessing it easier |
505 | Value* input() { |
506 | AT_ASSERT(inputs_.size() == 1); |
507 | return inputs_.at(0); |
508 | } |
509 | Value* output() { |
510 | AT_ASSERT(outputs_.size() == 1); |
511 | return outputs_.at(0); |
512 | } |
513 | const Value* output() const { |
514 | AT_ASSERT(outputs_.size() == 1); |
515 | return outputs_.at(0); |
516 | } |
517 | const Value* input() const { |
518 | AT_ASSERT(inputs_.size() == 1); |
519 | return inputs_.at(0); |
520 | } |
521 | // Access a particular input. This is a checked index. |
522 | Value* input(size_t i) const { |
523 | return inputs_.at(i); |
524 | } |
525 | |
526 | bool hasNamedInput(const std::string& unqualName) const; |
527 | Value* namedInput(const std::string& unqualName) const; |
528 | Value* namedInput(Symbol name) const; |
529 | |
530 | c10::optional<IValue> get(Symbol name) const; |
531 | |
  // Typed variant of get(): the statically-known value of input `name`
  // converted to T, or nullopt when the input is not a constant.
  template <typename T>
  c10::optional<T> get(Symbol name) const {
    if (auto v = get(name)) {
      return v->template to<T>();
    }
    return c10::nullopt;
  }
539 | |
540 | // Returns true if the value of input name is statically known |
541 | bool is_constant(Symbol name) const { |
542 | return static_cast<bool>(get(name)); |
543 | } |
544 | bool mustBeNone() const; |
545 | |
546 | bool isNondeterministic() const; |
547 | bool hasSideEffects() const; |
548 | |
549 | // instructions lowered by the interpreter and not run in the optimized graph |
  bool notExecutedOp() const {
    // Constants and the profiling ops are lowered by the interpreter
    // rather than dispatched as operators (see comment above).
    return kind_ == prim::Constant || kind_ == prim::profile ||
        kind_ == prim::profile_ivalue;
  }
554 | |
555 | // Graphs |
556 | |
557 | // Note [Topological invariant] |
558 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
559 | // We always maintain an up-to-date topological ordering of all nodes via |
560 | // the next()/prev() links. All transformations to graphs must preserve |
561 | // this topological ordering: for example, it is only valid to 'addInput' |
562 | // with an input which is topologically before the current node. |
563 | // |
564 | // Usually, it is obvious whether or not topological order is maintained; |
565 | // for example, if you are adding nodes to the end of the topsort, it's |
566 | // impossible for them to refer to inputs that are not in the topsort. |
567 | // If it is not obvious, please comment accordingly. |
568 | |
569 | // Add 'node' as an input to 'this' at the end of existing |
570 | // arguments. Returns the added node for ease of chaining. |
571 | // |
572 | // Given: %3 = f(%1, %2) |
573 | // Execute: %3.addInput(%4) |
574 | // Result: %3 = f(%1, %2, %4) |
575 | Value* addInput(Value* value); |
576 | |
577 | // Add 'value' as an input to 'this' at the specified position in the |
578 | // arguments. Returns the added value for ease of chaining. |
579 | Value* insertInput(size_t i, Value* value); |
580 | |
581 | // Replace the input of 'this' at position 'i' with |
582 | // 'newValue', returning the old node. |
583 | // |
584 | // Given: %3 = f(%1, %2) |
585 | // Execute: %3.replaceInput(1, %4) |
586 | // Result: %3 = f(%1, %4) |
587 | Value* replaceInput(size_t i, Value* newValue); |
588 | |
589 | // Replace all occurrences of 'from' in the inputs of this |
590 | // node with 'to'. Corresponds to llvm's replaceUsesOfWith. |
591 | // |
592 | // Given: %3 = f(%1, %2, %1) |
593 | // Execute: %3.replaceInputWith(%1, %4) |
594 | // Result: %3 = f(%4, %2, %4) |
595 | void replaceInputWith(Value* from, Value* to); |
596 | |
597 | Value* addOutput(); |
598 | |
599 | Value* insertOutput(size_t i); |
600 | |
601 | void eraseOutput(size_t i); |
602 | |
603 | Block* addBlock(); |
604 | void eraseBlock(size_t i); |
605 | |
606 | // Each Node can have a list of subblocks. These are used to define structured |
607 | // nested control flow operators such as If and Loop. |
608 | // The meaning of a block is specific to the kind of node it is in, but |
609 | // all blocks share these semantics: |
610 | // * Nested lexical scoping: If a node 'Parent' has a subblock which contains |
611 | // a node 'Child', Child can use any value that was in scope for the Parent |
612 | // node in addition to any values defined before 'Child' in the subblock. |
613 | // * The list of inputs to the block are in scope for the duration of the |
614 | // block |
615 | // * the outputs of the Parent node are not in scope for the subblocks |
  // Typically the inputs to a block that represents control flow act as
  // the equivalent of phi-nodes in standard SSA form,
  // defining a new Value to represent any term that has multiple
  // definitions depending on how control flowed. Outputs of the node
  // containing control flow serve a similar purpose, defining new values
  // for variables that would have different definitions depending on
  // which way control flowed.
623 | |
624 | at::ArrayRef<Block*> blocks() { |
625 | return blocks_; |
626 | } |
627 | at::ArrayRef<const Block*> blocks() const { |
628 | // Vectors are not convertible in const-ness of elements, but |
629 | // raw pointers are. |
630 | return {blocks_.data(), blocks_.size()}; |
631 | } |
632 | |
633 | // Is 'this' before 'n' in the topological order? |
634 | bool isBefore(const Node* n) const; |
635 | |
636 | // Is 'this' after 'n' in the topological order? |
637 | bool isAfter(const Node* n) const; |
638 | |
639 | // Insert unattached 'this' node before 'n' in the topological order. |
640 | // Returns this (for chaining). |
641 | // |
642 | // Given: %3 = f(%1, %2) |
643 | // %4 = g(%3) |
644 | // and unattached: %5 = h(%1) |
645 | // Execute: %5.insertBefore(%4) |
646 | // Result: %3 = f(%1, %2) |
647 | // %5 = h(%1) |
648 | // %4 = g(%3) |
649 | Node* insertBefore(Node* n); |
650 | |
651 | // Insert unattached 'this' node after 'n' in the topological order. |
652 | // Returns this (for chaining). |
653 | // |
654 | // Given: %3 = f(%1, %2) |
655 | // %4 = g(%3) |
656 | // and unattached: %5 = h(%1) |
657 | // Execute: %5.insertAfter(%4) |
658 | // Result: %3 = f(%1, %2) |
659 | // %4 = g(%3) |
660 | // %5 = h(%1) |
661 | Node* insertAfter(Node* n); |
662 | |
663 | // Move 'this' (already in the graph) after 'n' in the topological order. |
664 | // |
665 | // NOTE: Does not check that value dependencies are preserved, see |
666 | // AliasDb::moveAfterTopologicallyValid |
667 | // |
668 | // Given: %2 = f(%1) |
669 | // %3 = g(%1) |
670 | // Execute: %2.moveAfter(%3) |
671 | // Result: %3 = g(%1) |
672 | // %2 = f(%1) |
673 | // |
674 | void moveAfter(Node* n); |
675 | |
676 | // Move a node 'n' (already in the graph) before 'this' in the topological |
677 | // order. |
678 | // |
679 | // NOTE: Does not check that value dependencies are preserved, see |
680 | // AliasDb::moveBeforeTopologicallyValid |
681 | // |
682 | // Given: %2 = f(%1) |
683 | // %3 = g(%1) |
684 | // Execute: %3.moveBefore(%2) |
685 | // Result: %3 = g(%1) |
686 | // %2 = f(%1) |
687 | void moveBefore(Node* n); |
688 | |
689 | // Remove the input at 'i' from this node. |
690 | // |
691 | // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling |
692 | // removeInput. |
693 | // |
694 | // Given: %3 = f(%1, %2) |
695 | // Execute: %3.removeInput(1) |
696 | // Result: %3 = f(%1) |
697 | void removeInput(size_t i); |
698 | |
699 | // Remove all inputs from a node. |
700 | // |
701 | // Given: %3 = f(%1, %2) |
702 | // Execute: %3.removeAllInputs() |
703 | // Result: %3 = f() |
704 | void removeAllInputs(); |
705 | |
706 | // Remove all outputs from a node. |
707 | // |
708 | // Given: %1, %2 = f() |
  // Execute: removeAllOutputs()
710 | // Result: = f() |
711 | void removeAllOutputs(); |
712 | |
713 | // Rearrange the ordering of inputs or outputs of a node |
714 | // Given: %3 = f(%1, %2) |
715 | // Execute: %3.permuteInputs({1, 0}) |
716 | // Result: %3 = f(%2, %1) |
717 | // Each index must appear exactly once |
718 | void permuteInputs(const std::vector<size_t>& new_inputs); |
719 | void permuteOutputs(const std::vector<size_t>& new_inputs); |
720 | |
721 | // iterators of the node list starting at this node |
722 | // useful for resuming a search starting at this node |
723 | inline graph_node_list_iterator iterator() { |
724 | return {this, 0}; |
725 | } |
726 | inline graph_node_list_iterator reverseIterator() { |
727 | return iterator().reverse(); |
728 | } |
729 | inline const_graph_node_list_iterator iterator() const { |
730 | return {this, 0}; |
731 | } |
732 | inline const_graph_node_list_iterator reverseIterator() const { |
733 | return iterator().reverse(); |
734 | } |
735 | |
736 | // Remove 'this' from the instruction list and deallocate it. |
737 | // |
738 | // Invariant: no outputs of 'this' may have any uses. |
739 | // |
740 | // Given: %2 = f(%1) |
741 | // %3 = g(%1) |
742 | // Execute: %2.destroy() |
743 | // Result: %3 = g(%1) |
744 | void destroy(); |
745 | |
746 | // Dynamically cast this node to the subclass indicated by the |
747 | // template variable, returning nullptr if the cast is invalid.. |
748 | // |
749 | // Example usage: if(auto s = n.cast<Select>()) { ... } |
750 | template <typename T> |
751 | T* cast() { |
752 | if (T::Kind == kind()) { |
753 | return static_cast<T*>(this); |
754 | } |
755 | return nullptr; |
756 | } |
757 | template <typename T> |
758 | const T* cast() const { |
759 | if (T::Kind == kind()) { |
760 | return static_cast<const T*>(this); |
761 | } |
762 | return nullptr; |
763 | } |
764 | |
  // Like cast<T>(), but throws (via TORCH_CHECK) instead of returning
  // nullptr when this node's kind does not match T::Kind.
  template <typename T>
  T* expect() {
    TORCH_CHECK(
        T::Kind == kind(),
        "expected a " ,
        T::Kind.toDisplayString(),
        " but found a " ,
        kind().toDisplayString());
    return static_cast<T*>(this);
  }
775 | |
776 | bool matches(const FunctionSchema& schema) const; |
777 | |
778 | // XXX: this function is meant to be used with string literals only! |
779 | bool matches( |
780 | const char* signature_literal, |
781 | at::ArrayRef<Symbol> const_inputs = {}) const; |
782 | |
783 | bool isMemberOf(const OperatorSet& os) const; |
  // True if this node matches one of the operators registered under its
  // kind in `om`: first an O(1) lookup by kind, then schema matching
  // against each candidate.
  template <typename T>
  bool isMemberOf(const OperatorMap<T>& om) const {
    auto it = om.map.find(kind());
    if (it == om.map.end()) {
      return false;
    }
    for (auto& op : it->second) {
      if (matches(op.first->schema())) {
        return true;
      }
    }
    return false;
  }
797 | |
798 | const FunctionSchema& schema() const; |
799 | const FunctionSchema* maybeSchema() const; |
800 | const Operator& getOperator() const; |
801 | Operation getOperation() const; |
802 | |
803 | const Operator* maybeOperator() const; |
804 | |
805 | void dump() const; |
806 | |
807 | std::ostream& print( |
808 | std::ostream& out, |
809 | size_t level, |
810 | std::vector<const Node*>* groups, |
811 | bool print_source_locations = true, |
812 | bool print_attributes = true, |
813 | bool print_scopes = true, |
814 | bool print_body = true) const; |
815 | |
816 | virtual ~Node() { |
817 | if (wrap_) { |
818 | wrap_->clear(); |
819 | } |
820 | } |
821 | |
822 | // Methods for accessing attributes |
823 | Node* copyAttributes(const Node& rhs) { |
824 | values_.clear(); |
825 | for (const AVPtr& i : rhs.values_) { |
826 | values_.push_back(i->clone()); |
827 | } |
828 | return this; |
829 | } |
830 | bool hasAttribute(Symbol name) const { |
831 | AT_ASSERT(name.is_attr()); |
832 | return findAttr(name, false) != values_.end(); |
833 | } |
834 | bool hasAttributeS(const std::string& name) const { |
835 | return hasAttribute(Symbol::attr(name)); |
836 | } |
837 | AttributeKind kindOf(Symbol name) const { |
838 | AT_ASSERT(name.is_attr()); |
839 | return (*findAttr(name, true))->kind(); |
840 | } |
841 | AttributeKind kindOfS(const std::string& name) const { |
842 | return kindOf(Symbol::attr(name)); |
843 | } |
844 | Node* removeAttribute(Symbol name) { |
845 | AT_ASSERT(name.is_attr()); |
846 | values_.erase(findAttr(name, true)); |
847 | return this; |
848 | } |
849 | Node* removeAttributeS(const std::string& name) { |
850 | return removeAttribute(Symbol::attr(name)); |
851 | } |
852 | bool hasAttributes() const { |
853 | return !values_.empty(); |
854 | } |
855 | size_t numAttributes() const { |
856 | return values_.size(); |
857 | } |
858 | // The names are returned in order, since name actually is the index. |
859 | std::vector<Symbol> attributeNames() const { |
860 | std::vector<Symbol> names; |
861 | names.reserve(values_.size()); |
862 | for (const AVPtr& a : values_) { |
863 | names.push_back(a->name); |
864 | } |
865 | return names; |
866 | } |
867 | std::vector<const char*> attributeNamesS() const { |
868 | std::vector<const char*> names; |
869 | names.reserve(values_.size()); |
870 | for (const AVPtr& a : values_) { |
871 | names.push_back(a->name.toUnqualString()); |
872 | } |
873 | return names; |
874 | } |
875 | |
// Generates a pair of accessors for one attribute kind:
//   Node* method_(Symbol, value)  -- setter, returns `this` for chaining
//   const ValueType& method(Symbol) -- getter
#define CREATE_ACCESSOR(Kind, method)                           \
  Node* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
    return setAttr<Kind##Attr>(                                 \
        name, std::forward<Kind##Attr::ConstructorType>(v));    \
  }                                                             \
  const Kind##Attr::ValueType& method(Symbol name) const {      \
    return getAttr<Kind##Attr>(name);                           \
  }

  CREATE_ACCESSOR(Float, f)
  CREATE_ACCESSOR(Complex, c)
  CREATE_ACCESSOR(Floats, fs)
  CREATE_ACCESSOR(ComplexVals, cs)
  CREATE_ACCESSOR(String, s)
  CREATE_ACCESSOR(Strings, ss)
  CREATE_ACCESSOR(Int, i)
  CREATE_ACCESSOR(Ints, is)
  CREATE_ACCESSOR(Graph, g)
  CREATE_ACCESSOR(Graphs, gs)
  CREATE_ACCESSOR(Type, ty)
  CREATE_ACCESSOR(Types, tys)
  CREATE_ACCESSOR(IValue, ival)

#undef CREATE_ACCESSOR

  // Our Graphs are not very const-correct, so we need to allow returning
  // non-const references too
  GraphAttr::ValueType& g(Symbol name) {
    return getAttr<GraphAttr>(name);
  }

  // does not use CREATE_ACCESSOR because we need additional asserts
  Node* t_(Symbol name, TensorAttr::ConstructorType v) {
    return setAttr<TensorAttr>(
        name, std::forward<TensorAttr::ConstructorType>(v));
  }
  const TensorAttr::ValueType& t(Symbol name) const {
    return getAttr<TensorAttr>(name);
  }

  // Tensor-list attribute accessors (setter/getter), same shape as above.
  Node* ts_(Symbol name, TensorsAttr::ConstructorType v) {
    return setAttr<TensorsAttr>(
        name, std::forward<TensorsAttr::ConstructorType>(v));
  }
  const TensorsAttr::ValueType& ts(Symbol name) const {
    return getAttr<TensorsAttr>(name);
  }

  // Deepest block that contains both this node and `n`.
  Block* findCommonAncestorBlockWith(Node* n);

  // Number of enclosing blocks between this node and the graph's top block.
  size_t blocksFromGraphBlock();

 private:
  void printAttrValue(std::ostream& out, const Symbol& name) const;
  void printAttributes(std::ostream& out, bool ignore_subgraph) const;
  // Creates or overwrites attribute `name` with a freshly allocated
  // AttributeValue of concrete type T. Returns `this` to allow chaining.
  template <typename T>
  Node* setAttr(Symbol name, typename T::ConstructorType v) {
    AT_ASSERT(name.is_attr());
    auto it = findAttr(name, false);
    auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (it == values_.end()) {
      // Attribute is not present yet: append it.
      values_.push_back(std::move(nv));
    } else {
      // Attribute already exists: replace its value in place.
      *it = std::move(nv);
    }
    return this;
  }
  // Looks up attribute `name` and returns a reference to its value.
  // Throws IRAttributeError if the attribute is missing (from findAttr) or
  // if it is stored with a concrete type other than T (the check below).
  template <typename T>
  typename T::ValueType& getAttr(Symbol name) const {
    AT_ASSERT(name.is_attr());
    auto it = findAttr(name, true);
    auto* child = dynamic_cast<T*>(it->get());
    if (child == nullptr) {
      // Attribute exists but holds a different attribute kind.
      throw IRAttributeError(name, true);
    }
    return child->value();
  }
  using AVPtr = AttributeValue::Ptr;
  // NB: For determinism, we use a vector rather than a hash map. This does
  // mean that lookups are O(n), so you shouldn't use Attributes to store
  // a big pile of messages.
  std::vector<AVPtr> values_;
  // Linear scan for attribute `name`; if `required` is true and the
  // attribute is absent, throws IRAttributeError.
  std::vector<AVPtr>::iterator findAttr(Symbol name, bool required) {
    AT_ASSERT(name.is_attr());
    auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
      return v->name == name;
    });
    if (required && it == values_.end()) {
      throw IRAttributeError(name, false);
    }
    AT_ASSERT(!required || it != values_.end());
    return it;
  }
  // Const overload of the above.
  std::vector<AVPtr>::const_iterator findAttr(Symbol name, bool required)
      const {
    AT_ASSERT(name.is_attr());
    auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) {
      return v->name == name;
    });
    if (required && it == values_.end()) {
      throw IRAttributeError(name, false);
    }
    AT_ASSERT(!required || it != values_.end());
    return it;
  }
983 | |
  // Which side of this node another node lies on, for ordering/move queries.
  enum class MoveSide { BEFORE, AFTER };
  bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;

  std::pair<Value*, const Argument&> findInput(Symbol name);
  // Lookup iterator in use list of _input i_ that corresponds to its use of
  // _this_
  use_list::iterator findUseForInput(size_t i);

  // remove the use of input i, this sets input i to nullptr, but
  // is only used internally to Node before setting it to a new value
  // or erasing the entry from the list.
  Value* dropInput(size_t i);

  // A node is in a block's node list iff its links are set; next() and
  // prev() are expected to be either both null or both non-null.
  bool inBlockList() const {
    if (next() == nullptr) {
      AT_ASSERT(prev() == nullptr);
    }
    return next() != nullptr;
  }

  void removeFromList();
  void lint() const;

  void assignTopoPosition();
 protected:
  // subclasses must override
  // this function is used by createClone to initialize a new version
  // of a node in another graph. It should allocate a new instance of the same
  // concrete type as 'this', but in graph 'g' which might be different
  // than graph_
  // (the base implementation allocates a plain Node of the same kind)
  virtual Node* allocNewInstance(Graph* g) {
    return new Node(g, kind());
  }
  // create a copy of all properties of Node s into this.
  // subclasses should extend if they have additional information to copy.
  // 'this' will be allocated with s->allocNewInstance(g) so it should have
  // the same concrete type as 's'
  virtual void cloneFrom(Node* s);
1023 | }; |
1024 | |
// A Block is a (possibly nested) list of nodes. Its inputs are modeled as
// the outputs of a sentinel "param" node (input_) and its outputs as the
// inputs of a sentinel "return" node (output_); the sentinels also anchor
// the circular node list so it is never empty.
struct Block {
  friend struct Node;
  friend struct Graph;

  AT_DISALLOW_COPY_AND_ASSIGN(Block);
  TORCH_API Block(Graph* graph_, Node* node_);

  at::ArrayRef<Value*> inputs() {
    return input_->outputs();
  }
  at::ArrayRef<const Value*> inputs() const {
    const auto& inputs = input_->outputs();
    return {inputs.data(), inputs.size()};
  }
  at::ArrayRef<Value*> outputs() {
    return output_->inputs();
  }
  at::ArrayRef<const Value*> outputs() const {
    return static_cast<const Node*>(output_)->inputs();
  }
  graph_node_list nodes() {
    return {input_, kNextDirection};
  }
  const_graph_node_list nodes() const {
    return {input_, kNextDirection};
  }
  // Sentinel node whose inputs are this block's outputs.
  Node* return_node() {
    return output_;
  }
  const Node* return_node() const {
    return output_;
  }
  // Sentinel node whose outputs are this block's inputs.
  Node* param_node() {
    return input_;
  }
  const Node* param_node() const {
    return input_;
  }
  /**
   * @warning NEVER pass raw pointer of smart pointer managed Graph to Python.
   * Check #87343 for details.
   */
  Graph* owningGraph() {
    return graph_;
  }
  const Graph* owningGraph() const {
    return graph_;
  }
  // The node that contains this block, or nullptr for a graph's root block.
  Node* owningNode() {
    return owning_node_;
  }
  const Node* owningNode() const {
    return owning_node_;
  }

  Value* addInput(const std::string& name = "" ) {
    Value* v = input_->addOutput();
    v->setDebugName(name);
    return v;
  }
  Value* insertInput(size_t i, const std::string& name = "" ) {
    Value* v = input_->insertOutput(i);
    v->setDebugName(name);
    return v;
  }
  void eraseInput(size_t i) {
    input_->eraseOutput(i);
  }
  void removeAllInputs() {
    input_->removeAllOutputs();
  }
  // Registers `v` as a new block output; returns its output index.
  size_t registerOutput(Value* v) {
    output_->addInput(v);
    return outputs().size() - 1;
  }
  size_t insertOutput(size_t i, Value* n) {
    output_->insertInput(i, n);
    return i;
  }
  void eraseOutput(size_t i) {
    output_->removeInput(i);
  }
  void removeAllOutputs() {
    output_->removeAllInputs();
  }

  void replaceOutput(size_t i, Value* n) {
    output_->replaceInput(i, n);
  }
  void permuteOutputs(const std::vector<size_t>& new_inputs) {
    output_->permuteInputs(new_inputs);
  }
  void permuteInputs(const std::vector<size_t>& new_inputs) {
    input_->permuteOutputs(new_inputs);
  }

  // Appends `n` (which must belong to the same graph and not yet be in a
  // block) just before the return sentinel, i.e. at the end of the block.
  Node* appendNode(Node* n) {
    AT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
    n->insertBefore(output_);
    return n;
  }
  // Prepends `n` just after the param sentinel, i.e. at the block start.
  Node* prependNode(Node* n) {
    AT_ASSERT(n->graph_ == graph_ && !n->inBlockList());
    n->insertAfter(input_);
    return n;
  }

  // clone all inputs, nodes, and outputs from src and append them
  // to the inputs, nodes, and outputs of this block
  // value_map is used whenever a node in src references a free variable
  // in src to look up its corresponding value
  TORCH_API void cloneFrom(Block* src, std::function<Value*(Value*)> value_map);
  TORCH_API void remapTypes(const std::function<TypePtr(TypePtr)>& type_map);

  // Lazily creates (and caches) a wrapper that can be invalidated when the
  // block is destroyed; used for safe exposure to Python.
  TORCH_API std::shared_ptr<Wrap<Block>> wrap() {
    if (!wrap_) {
      wrap_ = std::make_shared<Wrap<Block>>(this);
    }
    return wrap_;
  }

  virtual ~Block() {
    if (wrap_) {
      // Invalidate any outstanding wrapper so it cannot dangle.
      wrap_->clear();
    }
  }

  // Drops all outputs, destroys nodes in reverse order (so uses internal to
  // the block disappear before their producers), then drops all inputs.
  void clear() {
    removeAllOutputs();
    for (auto it = nodes().rbegin(); it != nodes().rend(); it++) {
      it.destroyCurrent();
    }
    removeAllInputs();
  }

 private:
  void reIndexTopology();

  // get rid of all nodes
  // destroys in reverse order so that uses internal to this block
  // do not have to be removed before you can destroy the block
  void destroy();

  Graph* const graph_;
  // holds outputs in a way that can be reflected
  // as a Use object
  // also used as the beginning/end of the circular node list to avoid
  // having corner cases where the list is empty.
  Node* const output_;
  Node* const input_;
  Node* const
      owning_node_; // either the node that has this block or nullptr for root
  // a managing wrapper for Python to allow invalidation
  std::shared_ptr<Wrap<Block>> wrap_;
};
1180 | |
// A Graph owns all of its Nodes, Values, and Blocks. Its observable
// structure is a single top-level Block (block_); most of the member
// functions below simply forward to it.
struct Graph : std::enable_shared_from_this<Graph> {
  AT_DISALLOW_COPY_AND_ASSIGN(Graph);
  friend struct Node;
  friend struct Value;
  friend struct Block;

 private:
  // only used to keep track of allocated nodes
  // actual representation of Graph is done with
  // inputs, outputs, nodes

  std::unordered_set<const Node*> all_nodes;
  std::unordered_set<const Value*> all_values;
  std::unordered_set<const Block*> all_blocks;
  // Monotonically increasing counter used to assign each Value a unique id.
  size_t next_unique_;

  std::unordered_map<std::string, Value*> unique_names_;
  // name_base_suffix tracks largest suffix currently used by all names sharing
  // same name_base. Key of this map is name_base, value is largest suffix
  // numeric value.
  std::unordered_map<std::string, size_t> name_base_suffix_;

  ScopePtr current_scope_;

  Block* const block_;
  // when insertNode() is called, the node is inserted before this node
  // by default this is set to append to the top level block
  Node* insert_before_;

  c10::optional<size_t> op_version_;

 public:
  Graph(ScopePtr scope_root = c10::make_intrusive<Scope>())
      : next_unique_(0),
        current_scope_(std::move(scope_root)),
        block_(new Block(this, nullptr)),
        insert_before_(return_node()) {}

  at::ArrayRef<Value*> inputs() {
    return block_->inputs();
  }
  at::ArrayRef<const Value*> inputs() const {
    const Block& block = *block_;
    return block.inputs();
  }
  at::ArrayRef<Value*> outputs() {
    return block_->outputs();
  }
  at::ArrayRef<const Value*> outputs() const {
    const Block& block = *block_;
    return block.outputs();
  }
  graph_node_list nodes() {
    return block_->nodes();
  }
  const_graph_node_list nodes() const {
    const Block& block = *block_;
    return block.nodes();
  }
  Node* param_node() {
    return block_->param_node();
  }
  const Node* param_node() const {
    return block_->param_node();
  }
  Node* return_node() {
    return block_->return_node();
  }
  const Node* return_node() const {
    return block_->return_node();
  }
  // Map from debug name to the Value carrying that name.
  const std::unordered_map<std::string, Value*>& debugNames() const {
    return unique_names_;
  }

  TORCH_API void push_scope(const std::string& scope_name);
  TORCH_API void pop_scope();

  ScopePtr current_scope() {
    return current_scope_;
  }

  void set_op_version(c10::optional<size_t> version) {
    op_version_ = version;
  }

  c10::optional<size_t> get_op_version() {
    return op_version_;
  }

  void set_current_scope(ScopePtr scope) {
    current_scope_ = std::move(scope);
  }

  Value* addInput(const std::string& name = "" ) {
    return block_->addInput(name);
  }
  Value* insertInput(size_t i, const std::string& name = "" ) {
    return block_->insertInput(i, name);
  }
  void eraseInput(size_t i) {
    block_->eraseInput(i);
  }
  size_t registerOutput(Value* n) {
    return block_->registerOutput(n);
  }
  void eraseOutput(size_t i) {
    block_->eraseOutput(i);
  }

  // create*() functions allocate nodes owned by this graph but do NOT insert
  // them into a block; pair with insertNode()/appendNode() to place them.
  TORCH_API Node* create(NodeKind kind, size_t num_outputs = 1);
  TORCH_API Node* create(
      NodeKind kind,
      ArrayRef<Value*> inputs,
      size_t num_outputs = 1);

  TORCH_API Node* createNone();
  TORCH_API Node* createAutogradZero();
  TORCH_API Node* createUninitialized(TypePtr typ);
  TORCH_API Node* createWithSubgraph(Symbol kind);
  TORCH_API Node* createDifferentiableSubgraph();
  TORCH_API Node* createTuple(
      at::ArrayRef<Value*> values,
      TupleTypePtr optional_named_tuple = nullptr);
  TORCH_API Node* createTupleUnpack(Value* v);
  TORCH_API Node* createTupleIndex(
      Value* tup,
      Value* idx,
      const TypePtr& output_type);
  TORCH_API Node* createTupleSlice(
      Value* tup,
      int64_t beg,
      int64_t step_size,
      int64_t num_values);
  TORCH_API Node* createEnumName(Value* e);
  TORCH_API Node* createEnumValue(Value* e);
  TORCH_API Node* createList(
      const TypePtr& contained_type,
      at::ArrayRef<Value*> values);
  TORCH_API Node* createListUnpack(Value* v, size_t size);
  TORCH_API Node* createDict(
      const TypePtr& key_type,
      const TypePtr& value_type,
      at::ArrayRef<Value*> keys,
      at::ArrayRef<Value*> values);
  TORCH_API Node* createNumToTensor(Value* value);
  TORCH_API Node* createObject(const ClassTypePtr& type);
  TORCH_API Node* createSetAttr(
      Value* obj,
      const std::string& field,
      Value* newValue);
  TORCH_API Node* createGetAttr(Value* obj, const std::string& field);
  Value* insertGetAttr(Value* obj, const std::string& field) {
    return insertNode(createGetAttr(obj, field))->output();
  }
  TORCH_API Node* createStore(const std::string& name, Value* v);
  TORCH_API Node* createLoad(const std::string& name, const TypePtr& type);
  TORCH_API Node* createIsInstance(Value* v, at::ArrayRef<TypePtr> types);

  TORCH_API Value* insertUncheckedCast(Value* v, TypePtr type);

  // Insert a ToList operator with argument \p v and output type \p type.
  // \returns the output of the operation.
  TORCH_API Value* insertToList(Value* v, TypePtr type);

  TORCH_API Value* insertFunctionCall(
      Function* callee,
      const MatchedSchema& matched);
  TORCH_API Value* insertMethodCall(
      std::string method_name,
      const MatchedSchema& matched);

  // Note: defined in python_ir.cpp and can be used only in python extension
  Node* createPythonOp(
      THPObjectPtr&& pyobj,
      const std::string& cconv,
      pyobj_list&& scalar_args);
  // clone n, making a new node in _this_ graph.
  // use value_map to translate inputs of n to inputs of the cloned node
  // if copy_blocks is false, it will not recursively clone the nested blocks
  // this node contains.
  TORCH_API Node* createClone(
      Node* n,
      const std::function<Value*(Value*)>& value_map,
      bool copy_blocks = true);

  // Insert constant IValue into the graph.
  TORCH_API Value* insertConstant(
      const IValue& val,
      c10::optional<SourceRange> loc = c10::nullopt,
      c10::optional<ScopePtr> scope = c10::nullopt);

  // Schema-driven insert:
  // This inserts a node into the graph with inputs determined from args and
  // kwargs using Python argument matching rules, and checks that the op matches
  // a known schema.
  //
  // If this node successfully completes, it guarantees the node
  // is a correctly-formed invocation of opname
  TORCH_API Value* insert(
      Symbol opname,
      at::ArrayRef<NamedValue> args,
      at::ArrayRef<NamedValue> kwargs = {},
      const c10::optional<SourceRange>& range = {});

  Node* appendNode(Node* n) {
    return block_->appendNode(n);
  }

  Node* prependNode(Node* n) {
    return block_->prependNode(n);
  }

  // insert before insert_before_ node
  // initialized to insert at the end of the top level block
  // can be changed with setInsertPoint()
  Node* insertNode(Node* n) {
    AT_ASSERT(
        insert_before_->inBlockList() &&
        "insert point node is no longer in a block list" );
    return n->insertBefore(insert_before_);
  }
  // set where nodes are inserted to append to the end of this block
  void setInsertPoint(Block* b) {
    AT_ASSERT(b->owningGraph() == this);
    insert_before_ = b->return_node();
  }
  // set where nodes are inserted to insert _before_ this node
  // for implementation simplicity we only support inserting before a node for
  // now
  void setInsertPoint(Node* n) {
    AT_ASSERT(n->owningGraph() == this && n->inBlockList());
    insert_before_ = n;
  }
  Node* insertPoint() {
    return insert_before_;
  }

  // the top level block
  Block* block() {
    return block_;
  }
  const Block* block() const {
    return block_;
  }

  // Checks well-formedness and invariants of graph
  TORCH_API void lint() const;
  // for use in debugger
  TORCH_API void dump() const;

  TORCH_API ~Graph();

  TORCH_API std::string toString(bool print_source_locations = true) const;

  TORCH_API std::ostream& print(
      std::ostream& out,
      bool print_source_locations = true) const;

  friend TORCH_API std::ostream& operator<<(std::ostream& out, const Graph& g);

  TORCH_API std::shared_ptr<Graph> copy();
  TORCH_API std::unique_ptr<Graph> copyUnique();
  TORCH_API void remapTypes(const std::function<TypePtr(TypePtr)>& type_map);

 private:
  friend TORCH_API void Lint(const AliasDb* db);
  TORCH_API void freeNode(Node* n);
  TORCH_API void freeValue(Value* v);
  TORCH_API void freeBlock(Block* b);
  void cloneFrom(Graph& src);
};
1453 | |
/** \brief A utility class for setting temporary insertion points.
 *
 * When an object of this class is created, it stores the current insertion
 * point, sets the new one, and restores the original insertion point when the
 * object is destroyed.
 */
1460 | struct WithInsertPoint { |
1461 | WithInsertPoint(Node* n) : prev_(n->owningGraph()->insertPoint()) { |
1462 | n->owningGraph()->setInsertPoint(n); |
1463 | } |
1464 | WithInsertPoint(Block* b) : WithInsertPoint(b->return_node()) {} |
1465 | |
1466 | ~WithInsertPoint() { |
1467 | prev_->owningGraph()->setInsertPoint(prev_); |
1468 | } |
1469 | |
1470 | private: |
1471 | Node* prev_; |
1472 | }; |
1473 | |
/** \brief A utility class for setting temporary scopes.
 *
 * When an object of this class is created, it stores the current scope, sets
 * the new one, and restores the original scope when the object is destroyed.
 */
1479 | struct WithCurrentScope { |
1480 | WithCurrentScope(Graph& g, ScopePtr scope) |
1481 | : graph_(&g), prev_scope_(g.current_scope()) { |
1482 | g.set_current_scope(std::move(scope)); |
1483 | } |
1484 | ~WithCurrentScope() { |
1485 | graph_->set_current_scope(prev_scope_); |
1486 | } |
1487 | |
1488 | private: |
1489 | Graph* graph_; |
1490 | ScopePtr prev_scope_; |
1491 | }; |
1492 | |
// Constructs the `offset_`-th output of `node_`.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
inline Value::Value(Node* node_, size_t offset_)
    : node_(node_),
      offset_(offset_),
      // unique id drawn from the owning graph's monotonically increasing
      // counter
      unique_(node_->graph_->next_unique_++),
      type_(TensorType::get()) {
  // Register this value in the graph's allocation-tracking set.
  node_->graph_->all_values.emplace(this);
}
1501 | |
// Sets this value's type. DynamicType is collapsed to its fallback
// representation before being stored.
inline Value* Value::setType(TypePtr type) {
  AT_ASSERT(type);
  if (auto dyn = type->castRaw<c10::DynamicType>()) {
    type = dyn->fallback();
  }
  type_ = std::move(type);
  // Reset each user's op_ so it is re-resolved; presumably the cached
  // operator match can depend on this value's type — verify against
  // Node::getOperator.
  for (Use& use : uses_) {
    use.user->op_ = nullptr;
  }
  return this;
}
1513 | |
// A value's owning graph is the graph of its producing node.
inline Graph* Value::owningGraph() {
  return node()->owningGraph();
}

inline const Graph* Value::owningGraph() const {
  return node()->owningGraph();
}
1521 | |
1522 | /************* All nodes not required to be defined before Graph **************/ |
// A prim::profile node that carries a callback invoked with a vector of
// runtime IValues.
struct ProfileOp : public Node {
  static const Symbol Kind;
  ProfileOp(Graph* graph, std::function<void(std::vector<IValue>&)> callback)
      : Node(graph, ::c10::prim::profile), callback_(std::move(callback)) {}

  void cloneFrom(Node* other_) override;
  Node* allocNewInstance(Graph* g) override;

  const std::function<void(std::vector<IValue>&)>& getCallback() const {
    return callback_;
  }

  void setCallback(std::function<void(std::vector<IValue>&)> callback) {
    callback_ = std::move(callback);
  }

  // Accessors for has_seen_tensor_; the flag is set externally via
  // setHasSeenTensor (presumably by profiling code when a tensor is
  // observed — not determined here).
  bool hasSeenTensor() const {
    return has_seen_tensor_;
  }

  void setHasSeenTensor(bool has_seen_tensor) {
    has_seen_tensor_ = has_seen_tensor;
  }

 private:
  std::function<void(std::vector<IValue>&)> callback_;
  bool has_seen_tensor_ = false;
};
1551 | |
1552 | struct TORCH_API ProfileIValueOp : public Node { |
1553 | static const Symbol Kind; |
1554 | ProfileIValueOp( |
1555 | Graph* graph, |
1556 | std::function<void(std::vector<IValue>&)> callback) |
1557 | : Node(graph, ::c10::prim::profile_ivalue), |
1558 | callback_(std::move(callback)) {} |
1559 | |
1560 | void cloneFrom(Node* other_) override; |
1561 | Node* allocNewInstance(Graph* g) override; |
1562 | |
1563 | const std::function<void(std::vector<IValue>&)>& getCallback() const { |
1564 | return callback_; |
1565 | } |
1566 | |
1567 | void setCallback(std::function<void(std::vector<IValue>&)> callback) { |
1568 | callback_ = callback; |
1569 | } |
1570 | |
1571 | private: |
1572 | std::function<void(std::vector<IValue>&)> callback_; |
1573 | }; |
1574 | |
1575 | // execute a Python function, used for Ops we can't optimize but that we want to |
1576 | // optimize around |
1577 | // |
1578 | // Note: actual implementation (ConcretePythonOp) is defined in python_ir.cpp |
1579 | // which is not included in libtorch.so. We still include some bits and pieces |
1580 | // of PythonOp here to enable writing simple passes generically. In general, |
1581 | // python-aware bits need to be moved to the descendant classes. |
struct TORCH_API PythonOp : public Node {
  using Node::Node;

  // Human-readable name of the wrapped Python callable.
  virtual std::string name() const = 0;
  // Writes the op's captured scalar arguments to `out`.
  virtual void writeScalars(std::ostream& out) const = 0;
  void cloneFrom(Node* other_) override = 0;
  Node* allocNewInstance(Graph* g) override = 0;
  // recover the autograd.Function instance, if this PythonOp's function
  // was originally SomeFunction.apply
  // used in ONNX for discovering symbolics
  virtual c10::optional<THPObjectPtr> autogradFunction() const = 0;

  // Python-specific well-formedness checks, complementing Node::lint().
  virtual void lint_python() const = 0;
};
1596 | |
// Checks well-formedness of \p graph (see Graph::lint).
TORCH_API void LintGraph(const std::shared_ptr<Graph>& graph);

// Unpacks tuple value \p v; returns the resulting element values.
TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);

/** Insert graph \p CALLEE into graph \p G using \p INPUTS as input values.
 * The insertion happens at the current insertion point.
 * Optionally, one can also pass \p VALUE_MAP to get a map between \p CALLEE
 * values and their cloned copies in \p G.
 */
TORCH_API std::vector<Value*> insertGraph(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs);
TORCH_API std::vector<Value*> insertGraph(
    Graph& g,
    Graph& callee,
    ArrayRef<Value*> inputs,
    std::unordered_map<Value*, Value*>& value_map);

/** Insert function \p CALLEE after node \p TO_REPLACE, remove the node and
 * replace all its uses with corresponding outputs of the inserted function.
 * This asserts that the number of outputs of the original node and the
 * graph are the same.
 */
TORCH_API std::vector<Value*> inlineCallTo(
    Node* to_replace,
    GraphFunction* callee,
    bool use_graph = true);

TORCH_API std::vector<Value*> inlineCallTo(
    Node* to_replace,
    GraphFunction* callee,
    Graph* callee_graph);

/** If there is only one value in \p OUTPUTS and its kind is Tuple, insert a
 * tuple unpack node and return the resulting values.
 */
TORCH_API std::vector<Value*> unpackOutputs(const std::vector<Value*>& outputs);

// Find all nodes of kind \p kind, optionally recursing into nested blocks.
TORCH_API std::vector<Node*> findAllNodes(Graph& g, Symbol kind, bool recurse);
TORCH_API std::vector<Node*> findAllNodes(Block& b, Symbol kind, bool recurse);
TORCH_API std::vector<Node*> findAllNodes(
    at::ArrayRef<Block*> a,
    Symbol kind,
    bool recurse);
1642 | |
// A set of operators built from schema string literals, stored grouped by
// symbol name.
struct TORCH_API OperatorSet {
  OperatorSet(std::initializer_list<const char*> sig_literals);
  std::vector<std::shared_ptr<Operator>> getOps() const;
  void insert(std::initializer_list<const char*> sig_literals);

 private:
  friend struct Node;
  std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>> ops;
};
1652 | |
1653 | template <typename T> |
1654 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
1655 | struct OperatorMap { |
1656 | // Type aliasing |
1657 | using OpMapType = typename std::pair<std::shared_ptr<Operator>, T>; |
1658 | using ValueType = std::vector<OpMapType>; |
1659 | using MapType = std::unordered_map<Symbol, ValueType>; |
1660 | |
1661 | OperatorMap() = default; |
1662 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
1663 | explicit OperatorMap( |
1664 | std::initializer_list<std::pair<std::shared_ptr<Operator>, T>> init) { |
1665 | insert(init); |
1666 | } |
1667 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
1668 | explicit OperatorMap(std::initializer_list<std::pair<const char*, T>> init) { |
1669 | insert(init); |
1670 | } |
1671 | |
1672 | void insert(const std::shared_ptr<Operator>& op, T val) { |
1673 | // Remove if exists before insert |
1674 | erase(op); |
1675 | map[Symbol::fromQualString(op->schema().name())].emplace_back( |
1676 | std::make_pair(op, val)); |
1677 | } |
1678 | |
1679 | void insert(const OperatorSet& op_set, T val) { |
1680 | for (auto& op : op_set.getOps()) { |
1681 | insert(op, val); |
1682 | } |
1683 | } |
1684 | |
1685 | void insert( |
1686 | std::initializer_list<std::pair<std::shared_ptr<Operator>, T>> v) { |
1687 | for (auto& el : v) { |
1688 | insert(el.first, el.second); |
1689 | } |
1690 | } |
1691 | |
1692 | void insert(std::initializer_list<std::pair<const char*, T>> v) { |
1693 | for (auto& el : v) { |
1694 | insert(getOperatorForLiteral(el.first), el.second); |
1695 | } |
1696 | } |
1697 | |
1698 | void erase(const std::shared_ptr<Operator>& op) { |
1699 | auto it = map.find(Symbol::fromQualString(op->schema().name())); |
1700 | if (it == map.end()) { |
1701 | return; |
1702 | } |
1703 | for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) { |
1704 | if (vit->first->schema() == op->schema()) { |
1705 | it->second.erase(vit); |
1706 | break; |
1707 | } |
1708 | } |
1709 | if (it->second.size() == 0) { |
1710 | map.erase(Symbol::fromQualString(op->schema().name())); |
1711 | } |
1712 | } |
1713 | |
1714 | bool contains(const Operator& op) const { |
1715 | const auto it = map.find(Symbol::fromQualString(op.schema().name())); |
1716 | if (it == map.end()) { |
1717 | return false; |
1718 | } |
1719 | for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) { |
1720 | if (vit->first->schema() == op.schema()) { |
1721 | return true; |
1722 | } |
1723 | } |
1724 | return false; |
1725 | } |
1726 | |
1727 | bool contains(const Node* n) const { |
1728 | return n->maybeOperator() && contains(n->getOperator()); |
1729 | } |
1730 | |
1731 | c10::optional<T> find(const Operator& op) { |
1732 | const auto it = map.find(Symbol::fromQualString(op.schema().name())); |
1733 | if (it == map.end()) { |
1734 | return c10::nullopt; |
1735 | } |
1736 | for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) { |
1737 | if (vit->first->schema() == op.schema()) { |
1738 | return vit->second; |
1739 | } |
1740 | } |
1741 | return c10::nullopt; |
1742 | } |
1743 | |
1744 | // TODO: return iterator |
1745 | std::vector<OpMapType> getAllKeysAndValues() const { |
1746 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1747 | std::vector<OpMapType> keys_values; |
1748 | for (auto& symbol_mapping : map) { |
1749 | auto& vec = symbol_mapping.second; |
1750 | for (auto& pair : vec) { |
1751 | keys_values.push_back(pair); |
1752 | } |
1753 | } |
1754 | return keys_values; |
1755 | } |
1756 | |
1757 | private: |
1758 | friend struct Node; |
1759 | MapType map; |
1760 | }; |
1761 | |
1762 | template <typename T> |
1763 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
1764 | struct FunctionSchemaMap { |
1765 | // Type aliasing |
1766 | using FuncSchemaMapType = typename std::pair<FunctionSchema, T>; |
1767 | using ValueType = std::vector<FuncSchemaMapType>; |
1768 | using MapType = std::unordered_map<Symbol, ValueType>; |
1769 | |
1770 | FunctionSchemaMap() = default; |
1771 | void insert(const FunctionSchema& schema, T val) { |
1772 | // Remove if exists before insert |
1773 | erase(schema); |
1774 | map[Symbol::fromQualString(schema.name())].emplace_back( |
1775 | std::make_pair(schema, val)); |
1776 | } |
1777 | |
1778 | void erase(const FunctionSchema& schema) { |
1779 | auto it = map.find(Symbol::fromQualString(schema.name())); |
1780 | if (it == map.end()) { |
1781 | return; |
1782 | } |
1783 | for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) { |
1784 | if (vit->first == schema) { |
1785 | it->second.erase(vit); |
1786 | break; |
1787 | } |
1788 | } |
1789 | if (it->second.size() == 0) { |
1790 | map.erase(Symbol::fromQualString(schema.name())); |
1791 | } |
1792 | } |
1793 | |
1794 | bool contains(const FunctionSchema& schema) const { |
1795 | const auto it = map.find(Symbol::fromQualString(schema.name())); |
1796 | if (it == map.end()) { |
1797 | return false; |
1798 | } |
1799 | for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) { |
1800 | if (vit->first->schema() == schema) { |
1801 | return true; |
1802 | } |
1803 | } |
1804 | return false; |
1805 | } |
1806 | |
1807 | c10::optional<T> find(const FunctionSchema& schema) const { |
1808 | const auto it = map.find(Symbol::fromQualString(schema.name())); |
1809 | if (it == map.end()) { |
1810 | return c10::nullopt; |
1811 | } |
1812 | for (auto vit = it->second.begin(); vit != it->second.end(); ++vit) { |
1813 | if (vit->first == schema) { |
1814 | return vit->second; |
1815 | } |
1816 | } |
1817 | return c10::nullopt; |
1818 | } |
1819 | |
1820 | // TODO: return iterator |
1821 | std::vector<FuncSchemaMapType> getAllKeysAndValues() const { |
1822 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1823 | std::vector<FuncSchemaMapType> keys_values; |
1824 | for (auto& symbol_mapping : map) { |
1825 | auto& vec = symbol_mapping.second; |
1826 | for (auto& pair : vec) { |
1827 | keys_values.push_back(pair); |
1828 | } |
1829 | } |
1830 | return keys_values; |
1831 | } |
1832 | |
1833 | private: |
1834 | friend struct Node; |
1835 | MapType map; |
1836 | }; |
1837 | |
1838 | } // namespace jit |
1839 | } // namespace torch |
1840 | |