1 | #pragma once |
2 | |
3 | #include <unordered_map> |
4 | #include <vector> |
5 | |
6 | #include <torch/csrc/lazy/core/ir.h> |
7 | |
8 | namespace torch { |
9 | namespace lazy { |
10 | |
11 | class TORCH_API Util { |
12 | public: |
13 | // Tracks the emission status of the nodes during the post-order generation. |
14 | // It helps tracking loops within the computation graphs. |
15 | enum EmitStatus { |
16 | kNotEmitted, |
17 | kEmitting, |
18 | kEmitted, |
19 | }; |
20 | |
21 | using EmissionMap = std::unordered_map<const Node*, EmitStatus>; |
22 | |
23 | // Computes the post order from the given node, without using recursion. The |
24 | // emission map can be used as saved state, for multiple separate calls to |
25 | // this API. The returned post-order can be empty if the node has already been |
26 | // emitted inside the emission map. An error is generated if a loop is |
27 | // detected. |
28 | static std::vector<const Node*> ComputePostOrder( |
29 | const Node* node, |
30 | EmissionMap* emap); |
31 | |
32 | static std::vector<const Node*> ComputePostOrder( |
33 | c10::ArrayRef<const Node*> nodes, |
34 | EmissionMap* emap); |
35 | |
36 | // Same as above, but computes the post order on the set of nodes specified as |
37 | // argument. |
38 | static std::vector<const Node*> ComputePostOrder( |
39 | c10::ArrayRef<const Node*> nodes); |
40 | |
41 | // Retrieves the number of nodes within the graph whose sink are passed in the |
42 | // nodes argument. |
43 | static size_t GetGraphSize(c10::ArrayRef<const Node*> nodes); |
44 | }; |
45 | |
46 | } // namespace lazy |
47 | } // namespace torch |
48 | |