1#include <torch/csrc/lazy/core/ir_util.h>
2
3#include <c10/util/Logging.h>
4
5namespace torch {
6namespace lazy {
7
8std::vector<const Node*> Util::ComputePostOrder(
9 const Node* node,
10 EmissionMap* emap) {
11 std::vector<const Node*> post_order;
12 std::vector<const Node*> queue;
13 queue.push_back(node);
14 while (!queue.empty()) {
15 node = queue.back();
16 auto it = emap->find(node);
17 if (it == emap->end()) {
18 (*emap)[node] = kEmitting;
19 for (auto& output : node->operands()) {
20 auto oit = emap->find(output.node);
21 if (oit == emap->end()) {
22 queue.push_back(output.node);
23 } else {
24 TORCH_CHECK(
25 oit->second != kEmitting,
26 "Graph loop found at ",
27 output.node->ToString());
28 }
29 }
30 } else if (it->second == kEmitting) {
31 for (auto& output : node->operands()) {
32 auto oit = emap->find(output.node);
33 TORCH_CHECK(
34 oit != emap->end() && oit->second == kEmitted,
35 "Graph loop found at ",
36 output.node->ToString());
37 }
38 (*emap)[node] = kEmitted;
39 post_order.push_back(node);
40 queue.pop_back();
41 } else {
42 TORCH_CHECK(it->second == kEmitted);
43 queue.pop_back();
44 }
45 }
46 return post_order;
47}
48
49std::vector<const Node*> Util::ComputePostOrder(
50 c10::ArrayRef<const Node*> nodes,
51 EmissionMap* emap) {
52 std::vector<const Node*> post_order;
53 for (auto node : nodes) {
54 auto node_post_order = ComputePostOrder(node, emap);
55 post_order.insert(
56 post_order.end(), node_post_order.begin(), node_post_order.end());
57 }
58 return post_order;
59}
60
61std::vector<const Node*> Util::ComputePostOrder(
62 c10::ArrayRef<const Node*> nodes) {
63 EmissionMap emap;
64 return ComputePostOrder(nodes, &emap);
65}
66
67size_t Util::GetGraphSize(c10::ArrayRef<const Node*> nodes) {
68 return ComputePostOrder(nodes).size();
69}
70
71} // namespace lazy
72} // namespace torch
73