1 | #pragma once |
2 | |
3 | #include <memory> |
4 | #include <string> |
5 | #include <unordered_map> |
6 | #include <utility> |
7 | #include <vector> |
8 | |
9 | #include <torch/csrc/lazy/backend/backend_data.h> |
10 | #include <torch/csrc/lazy/backend/backend_device.h> |
11 | #include <torch/csrc/lazy/core/ir.h> |
12 | #include <torch/csrc/lazy/core/ir_util.h> |
13 | |
14 | namespace torch { |
15 | namespace lazy { |
16 | |
17 | class TORCH_API Computation { |
18 | public: |
19 | virtual int parameters_size() const = 0; |
20 | |
21 | virtual const std::vector<Shape>& parameter_shapes() const = 0; |
22 | |
23 | virtual const std::vector<std::string>& parameter_names() const = 0; |
24 | |
25 | virtual const Shape& result_shape() const = 0; |
26 | |
27 | virtual const std::string to_string() const = 0; |
28 | |
29 | virtual ~Computation() = default; |
30 | |
31 | // Indicates whether this computation is being executed inside a mark step |
32 | // Assume false unless set otherwise |
33 | bool in_mark_step = false; |
34 | }; |
35 | |
// Shared-ownership handle to a Computation.
using ComputationPtr = std::shared_ptr<Computation>;
37 | |
38 | // Keeps track of the code generation state. |
39 | class TORCH_API LoweringContext { |
40 | public: |
41 | LoweringContext(const std::string& name, BackendDevice device); |
42 | LoweringContext( |
43 | const std::string& name, |
44 | BackendDevice device, |
45 | c10::ArrayRef<const torch::lazy::Node*> post_order, |
46 | Util::EmissionMap emit_status); |
47 | |
48 | virtual ~LoweringContext() = default; |
49 | |
50 | static std::unique_ptr<LoweringContext> Create( |
51 | const std::string& name, |
52 | BackendDevice device, |
53 | c10::ArrayRef<const torch::lazy::Node*> post_order, |
54 | Util::EmissionMap emit_status); |
55 | |
56 | static std::unique_ptr<LoweringContext> Create( |
57 | const std::string& name, |
58 | BackendDevice device); |
59 | |
60 | const BackendDevice& device() const { |
61 | return device_; |
62 | }; |
63 | |
64 | // Retrieves the vector holding all the tensors associated with the parameter |
65 | // instructions which have been created. |
66 | const std::vector<BackendDataPtr>& GetParametersData() const; |
67 | |
68 | // Adds a new input/output alias. |
69 | virtual void SetUpAlias( |
70 | const std::vector<int64_t>& output_index, |
71 | int64_t param_number, |
72 | const std::vector<int64_t>& param_index, |
73 | bool must_alias = false) { |
74 | // Dummy default implementation to do nothing. |
75 | } |
76 | |
77 | // Check if parameter shape matches result at index. |
78 | virtual bool CheckResultShape( |
79 | const BackendDataPtr& parameter_data, |
80 | size_t result_idx) { |
81 | // Dummy default implementation to do nothing. |
82 | return false; |
83 | } |
84 | |
85 | // Adds the given output as a component of the result tuple and returns its |
86 | // assigned position within the tuple. |
87 | virtual size_t AddResult(const torch::lazy::Output& output) = 0; |
88 | |
89 | // Associates the given output with the input parameter of the given index and |
90 | // shape. Only used for the operator-by-operator execution, mostly for |
91 | // debugging purposes. |
92 | virtual void AddParameter( |
93 | const torch::lazy::Output& output, |
94 | size_t index, |
95 | const Shape& shape, |
96 | const std::string& name) = 0; |
97 | |
98 | // Build the computation capturing all the operations created with the |
99 | // embedded builder (returned by the builder() API). |
100 | virtual ComputationPtr Build() = 0; |
101 | |
102 | size_t GetEmittedNodeCount() const { |
103 | return emit_status_.size(); |
104 | } |
105 | |
106 | protected: |
107 | BackendDevice device_; |
108 | std::vector<BackendDataPtr> parameters_; |
109 | std::vector<size_t> parameter_sequence_; |
110 | Util::EmissionMap emit_status_; |
111 | }; |
112 | |
113 | } // namespace lazy |
114 | } // namespace torch |
115 | |