#pragma once

#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include <torch/csrc/api/include/torch/jit.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/ts_backend/ts_node_lowering.h>

namespace torch {
namespace lazy {

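// The list of TorchScript values produced when a single lazy IR node is
// lowered; most nodes emit one value, multi-output nodes emit several.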
using TSOpVector = std::vector<torch::jit::Value*>;

class TORCH_API TSComputation : public Computation {
 public:
  explicit TSComputation(const std::shared_ptr<torch::jit::Graph>& graph)
      : graph_(graph), graph_executor_(graph, "") {
    for (torch::jit::Value* input : graph_->inputs()) {
      parameter_names_.push_back(input->debugName());
    }
  }

  int parameters_size() const override {
    return parameter_names_.size();
  }

  const std::vector<Shape>& parameter_shapes() const override {
    throw std::runtime_error(
        "TODO(whc) implement TS computation shapes or change interface");
    return parameter_shapes_;
  }

  const std::vector<std::string>& parameter_names() const override {
    return parameter_names_;
  }

  const Shape& result_shape() const override {
    throw std::runtime_error(
        "TODO(whc) implement TS computation shapes or change interface");
    return result_shape_;
  }

  const std::string to_string() const override {
    std::ostringstream oss;
    oss << *graph_;
    return oss.str();
  }

  std::shared_ptr<torch::jit::Graph> graph() const {
    return graph_;
  }

  torch::jit::GraphExecutor& graph_executor() {
    return graph_executor_;
  }

 private:
  std::shared_ptr<torch::jit::Graph> graph_;
  torch::jit::GraphExecutor graph_executor_;
  std::vector<std::string> parameter_names_;
  std::vector<Shape> parameter_shapes_;
  Shape result_shape_;
};
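// Example usage (a minimal sketch; `graph` stands in for any
// std::shared_ptr<torch::jit::Graph> with two tensor inputs supplied by the
// caller):
//
//   TSComputation computation(graph);
//   std::cout << computation.to_string();  // dump the TorchScript IR
//
//   torch::jit::Stack stack;
//   stack.emplace_back(torch::ones({2, 2}));
//   stack.emplace_back(torch::ones({2, 2}));
//   computation.graph_executor().run(stack);  // results are left on the stack
//   at::Tensor result = stack.back().toTensor();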

class TORCH_API TSLoweringContext : public LoweringContext {
 public:
  TSLoweringContext(const std::string& name, BackendDevice device);

  TSLoweringContext(
      const std::string& name,
      BackendDevice device,
      c10::ArrayRef<const Node*> post_order,
      Util::EmissionMap emit_status);

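  // Records `output` as a result of the computation being built and returns
  // its position in the result tuple emitted by Build().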
  size_t AddResult(const Output& output) override {
    return AddResult(GetOutputOp(output));
  }

  void AddParameter(
      const torch::lazy::Output& output,
      size_t index,
      const Shape& shape,
      const std::string& name) override {
    TORCH_INTERNAL_ASSERT(false, "not implemented");
  }

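  // Lowers a single lazy Node into TorchScript operations on the underlying
  // graph and records the emitted values via AssignOutputOp.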
  void Lower(const Node* node);

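  // Finalizes the graph: every value recorded via AddResult is registered as
  // a graph output, in insertion order, and the graph is wrapped in a
  // TSComputation.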
  ComputationPtr Build() override {
    for (torch::jit::Value* output : root_tuple_) {
      graph_->block()->registerOutput(output);
    }
    return std::make_shared<TSComputation>(graph_);
  }

  // Retrieves the lowered operation for an output. If the requested output is
  // not available yet, the graph behind the output's Node is lowered, and the
  // corresponding TS operation returned.
  torch::jit::Value* GetOutputOp(const Output& output) {
    auto it = emitted_outputs_.find(output);
    if (it == emitted_outputs_.end()) {
      auto post_order = Util::ComputePostOrder(output.node, &emit_status_);
      for (auto node : post_order) {
        Lower(node);
      }
      // At this point the output must be present; otherwise there is an
      // issue with the lowering code.
      it = emitted_outputs_.find(output);
      TORCH_CHECK(
          it != emitted_outputs_.end(),
          "No TS operation emitted for output: ",
          output.ToString());
    }
    return it->second;
  }
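
  // A typical node lowering fetches its operand values through GetOutputOp
  // before emitting the jit::Node that consumes them (a sketch, assuming a
  // two-operand `node`):
  //
  //   torch::jit::Value* lhs = GetOutputOp(node->operand(0));
  //   torch::jit::Value* rhs = GetOutputOp(node->operand(1));
  //   // ...emit the TS op consuming lhs/rhs, then AssignOutputOp(...).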

  // Assigns the given TS operation to the specified output. As outputs are
  // lowered in a post-order fashion, later nodes should always find their
  // operands among the emitted outputs.
  void AssignOutputOp(const Output& output, torch::jit::Value* op);

  // If a parameter associated with data has already been declared, it will be
  // returned. Otherwise a new one will be created, associated with the tensor
  // held in data.
  torch::jit::Value* GetParameter(BackendDataPtr data);
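
  // For instance, lowering a DeviceData node typically maps it straight to a
  // graph parameter (a sketch, assuming `device_data` is such a node exposing
  // its BackendData via data()):
  //
  //   torch::jit::Value* param = GetParameter(device_data->data());
  //   AssignOutputOp(torch::lazy::Output(device_data), param);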

  std::shared_ptr<torch::jit::Graph> graph() const {
    return graph_;
  }

 private:
  struct Parameter {
    torch::jit::Value* param{nullptr};
    size_t index = 0;
  };

  size_t AddResult(torch::jit::Value* op) {
    root_tuple_.push_back(op);
    return root_tuple_.size() - 1;
  }

  std::shared_ptr<torch::jit::Graph> graph_;
  std::shared_ptr<torch::jit::GraphFunction> function_;
  std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
  std::vector<torch::jit::Value*> root_tuple_;
  OutputMap<torch::jit::Value*> emitted_outputs_;
};
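
// Typical end-to-end lowering flow (a sketch; `device`, `post_order`, and
// `emit_status` come from the caller, e.g. the backend's compile step, and
// `graph_outputs` stands in for whatever lazy Outputs should be returned):
//
//   TSLoweringContext loctx("TsGraph", device, post_order, emit_status);
//   for (const Node* node : post_order) {
//     loctx.Lower(node);
//   }
//   for (const Output& output : graph_outputs) {
//     loctx.AddResult(output);
//   }
//   ComputationPtr computation = loctx.Build();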

} // namespace lazy
} // namespace torch