#pragma once

#include <sstream>

#include <torch/csrc/api/include/torch/jit.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/ts_backend/ts_node_lowering.h>

namespace torch {
namespace lazy {

using TSOpVector = std::vector<torch::jit::Value*>;

class TORCH_API TSComputation : public Computation {
 public:
  TSComputation(const std::shared_ptr<torch::jit::Graph>& graph)
      : graph_(graph), graph_executor_(graph, "") {
    for (torch::jit::Value* input : graph_->inputs()) {
      parameter_names_.push_back(input->debugName());
    }
  }

  int parameters_size() const override {
    return parameter_names_.size();
  }

  const std::vector<Shape>& parameter_shapes() const override {
    throw std::runtime_error(
        "TODO(whc) implement TS computation shapes or change interface");
    return parameter_shapes_;
  }

  const std::vector<std::string>& parameter_names() const override {
    return parameter_names_;
  }

  const Shape& result_shape() const override {
    throw std::runtime_error(
        "TODO(whc) implement TS computation shapes or change interface");
    return result_shape_;
  }

  const std::string to_string() const override {
    std::ostringstream oss;
    oss << *graph_;
    return oss.str();
  }

  std::shared_ptr<torch::jit::Graph> graph() const {
    return graph_;
  }

  torch::jit::GraphExecutor& graph_executor() {
    return graph_executor_;
  }

 private:
  std::shared_ptr<torch::jit::Graph> graph_;
  torch::jit::GraphExecutor graph_executor_;
  std::vector<std::string> parameter_names_;
  std::vector<Shape> parameter_shapes_;
  Shape result_shape_;
};
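
// Illustrative usage sketch (not part of this header): a TSComputation pairs a
// TorchScript graph with a GraphExecutor, so once built it can be run by
// pushing its inputs onto a torch::jit::Stack. The graph here is assumed to
// take a single tensor parameter.
//
//   std::shared_ptr<torch::jit::Graph> graph = /* lowered TS graph */;
//   auto computation = std::make_shared<TSComputation>(graph);
//   torch::jit::Stack stack;
//   stack.emplace_back(at::randn({2, 2}));
//   computation->graph_executor().run(stack);
//   // The results are left on the stack as IValues.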

class TORCH_API TSLoweringContext : public LoweringContext {
 public:
  TSLoweringContext(const std::string& name, const BackendDevice device);

  TSLoweringContext(
      const std::string& name,
      BackendDevice device,
      c10::ArrayRef<const Node*> post_order,
      Util::EmissionMap emit_status);

  size_t AddResult(const Output& output) override {
    return AddResult(GetOutputOp(output));
  }

  void AddParameter(
      const torch::lazy::Output& output,
      size_t index,
      const Shape& shape,
      const std::string& name) override {
    TORCH_INTERNAL_ASSERT(false, "not implemented");
  }

  void Lower(const Node* node);

  ComputationPtr Build() override {
    for (torch::jit::Value* output : root_tuple_) {
      graph_->block()->registerOutput(output);
    }
    return std::shared_ptr<Computation>(new TSComputation(graph_));
  }

  // Retrieves the lowered operation for an output. If the requested output is
  // not available yet, the graph behind the output's Node is lowered, and the
  // corresponding TS operation returned.
  torch::jit::Value* GetOutputOp(const Output& output) {
    auto it = emitted_outputs_.find(output);
    if (it == emitted_outputs_.end()) {
      auto post_order = Util::ComputePostOrder(output.node, &emit_status_);
      for (auto node : post_order) {
        Lower(node);
      }
      // At this point the output better be present, otherwise there is an
      // issue with the lowering code.
      it = emitted_outputs_.find(output);
      TORCH_CHECK(
          it != emitted_outputs_.end(),
          "No TS operation emitted for output: ",
          output.ToString());
    }
    return it->second;
  }
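
  // Illustrative sketch (assumed call pattern, not defined in this header): a
  // node lowering resolves each operand to a JIT value on demand, e.g.
  //
  //   torch::jit::Value* input = loctx->GetOutputOp(node->operand(0));
  //
  // which lowers the operand's producer node first if it has not been emitted
  // yet.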

  // Assigns the given TS operation to the specified output. As outputs are
  // lowered in a post-order fashion, later nodes should always find their
  // operands among the emitted outputs.
  void AssignOutputOp(const Output& output, torch::jit::Value* op);

  // If a parameter associated with data has already been declared, it will be
  // returned. Otherwise a new one will be created, associated with the tensor
  // held in data.
  torch::jit::Value* GetParameter(BackendDataPtr data);

  std::shared_ptr<torch::jit::Graph> graph() const {
    return graph_;
  }

 private:
  struct Parameter {
    torch::jit::Value* param{nullptr};
    size_t index = 0;
  };

  size_t AddResult(torch::jit::Value* op) {
    root_tuple_.push_back(std::move(op));
    return root_tuple_.size() - 1;
  }

  std::shared_ptr<torch::jit::Graph> graph_;
  std::shared_ptr<torch::jit::GraphFunction> function_;
  std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
  std::vector<torch::jit::Value*> root_tuple_;
  OutputMap<torch::jit::Value*> emitted_outputs_;
};
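
// Illustrative driver sketch (assumed usage, not defined here): a backend
// typically lowers a post-order list of nodes into the context, registers the
// root outputs, and then builds the runnable computation. `device`,
// `post_order`, and `roots` are placeholders for values supplied by the
// caller.
//
//   TSLoweringContext loctx("TsGraph", device);
//   for (const Node* node : post_order) {
//     loctx.Lower(node);
//   }
//   for (const Output& root : roots) {
//     loctx.AddResult(root);
//   }
//   ComputationPtr computation = loctx.Build();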

} // namespace lazy
} // namespace torch