1#pragma once
2
3#include <unordered_map>
4#include <vector>
5
6#include <torch/csrc/lazy/core/ir.h>
7
8namespace torch {
9namespace lazy {
10
11class TORCH_API Util {
12 public:
13 // Tracks the emission status of the nodes during the post-order generation.
14 // It helps tracking loops within the computation graphs.
15 enum EmitStatus {
16 kNotEmitted,
17 kEmitting,
18 kEmitted,
19 };
20
21 using EmissionMap = std::unordered_map<const Node*, EmitStatus>;
22
23 // Computes the post order from the given node, without using recursion. The
24 // emission map can be used as saved state, for multiple separate calls to
25 // this API. The returned post-order can be empty if the node has already been
26 // emitted inside the emission map. An error is generated if a loop is
27 // detected.
28 static std::vector<const Node*> ComputePostOrder(
29 const Node* node,
30 EmissionMap* emap);
31
32 static std::vector<const Node*> ComputePostOrder(
33 c10::ArrayRef<const Node*> nodes,
34 EmissionMap* emap);
35
36 // Same as above, but computes the post order on the set of nodes specified as
37 // argument.
38 static std::vector<const Node*> ComputePostOrder(
39 c10::ArrayRef<const Node*> nodes);
40
41 // Retrieves the number of nodes within the graph whose sink are passed in the
42 // nodes argument.
43 static size_t GetGraphSize(c10::ArrayRef<const Node*> nodes);
44};
45
46} // namespace lazy
47} // namespace torch
48