1 | #include <torch/csrc/lazy/core/ir_util.h> |
2 | |
3 | #include <c10/util/Logging.h> |
4 | |
5 | namespace torch { |
6 | namespace lazy { |
7 | |
8 | std::vector<const Node*> Util::ComputePostOrder( |
9 | const Node* node, |
10 | EmissionMap* emap) { |
11 | std::vector<const Node*> post_order; |
12 | std::vector<const Node*> queue; |
13 | queue.push_back(node); |
14 | while (!queue.empty()) { |
15 | node = queue.back(); |
16 | auto it = emap->find(node); |
17 | if (it == emap->end()) { |
18 | (*emap)[node] = kEmitting; |
19 | for (auto& output : node->operands()) { |
20 | auto oit = emap->find(output.node); |
21 | if (oit == emap->end()) { |
22 | queue.push_back(output.node); |
23 | } else { |
24 | TORCH_CHECK( |
25 | oit->second != kEmitting, |
26 | "Graph loop found at " , |
27 | output.node->ToString()); |
28 | } |
29 | } |
30 | } else if (it->second == kEmitting) { |
31 | for (auto& output : node->operands()) { |
32 | auto oit = emap->find(output.node); |
33 | TORCH_CHECK( |
34 | oit != emap->end() && oit->second == kEmitted, |
35 | "Graph loop found at " , |
36 | output.node->ToString()); |
37 | } |
38 | (*emap)[node] = kEmitted; |
39 | post_order.push_back(node); |
40 | queue.pop_back(); |
41 | } else { |
42 | TORCH_CHECK(it->second == kEmitted); |
43 | queue.pop_back(); |
44 | } |
45 | } |
46 | return post_order; |
47 | } |
48 | |
49 | std::vector<const Node*> Util::ComputePostOrder( |
50 | c10::ArrayRef<const Node*> nodes, |
51 | EmissionMap* emap) { |
52 | std::vector<const Node*> post_order; |
53 | for (auto node : nodes) { |
54 | auto node_post_order = ComputePostOrder(node, emap); |
55 | post_order.insert( |
56 | post_order.end(), node_post_order.begin(), node_post_order.end()); |
57 | } |
58 | return post_order; |
59 | } |
60 | |
61 | std::vector<const Node*> Util::ComputePostOrder( |
62 | c10::ArrayRef<const Node*> nodes) { |
63 | EmissionMap emap; |
64 | return ComputePostOrder(nodes, &emap); |
65 | } |
66 | |
67 | size_t Util::GetGraphSize(c10::ArrayRef<const Node*> nodes) { |
68 | return ComputePostOrder(nodes).size(); |
69 | } |
70 | |
71 | } // namespace lazy |
72 | } // namespace torch |
73 | |