1 | #include <torch/csrc/lazy/ts_backend/ts_node_lowering.h> |
2 | |
3 | #include <ATen/Functions.h> |
4 | #include <torch/csrc/jit/frontend/sugared_value.h> |
5 | #include <torch/csrc/jit/jit_log.h> |
6 | #include <torch/csrc/lazy/backend/backend_interface.h> |
7 | #include <torch/csrc/lazy/core/helpers.h> |
8 | #include <torch/csrc/lazy/core/internal_ops/ltc_ops.h> |
9 | #include <torch/csrc/lazy/core/ir_builder.h> |
10 | #include <torch/csrc/lazy/core/lazy_graph_executor.h> |
11 | #include <torch/csrc/lazy/core/ops/utils.h> |
12 | #include <torch/csrc/lazy/core/permutation_util.h> |
13 | #include <torch/csrc/lazy/ts_backend/ir_builder.h> |
14 | #include <torch/csrc/lazy/ts_backend/ts_lowering_context.h> |
15 | |
16 | namespace torch { |
17 | namespace lazy { |
18 | |
19 | TSOpVector LowerBuiltin( |
20 | const torch::lazy::Node* node, |
21 | std::shared_ptr<torch::jit::GraphFunction> function, |
22 | const std::vector<torch::jit::NamedValue>& arguments, |
23 | const std::vector<torch::jit::NamedValue>& kwarguments = {}) { |
24 | return LowerTSBuiltin(function, node->op().op, arguments, kwarguments); |
25 | } |
26 | TSOpVector LowerBuiltin( |
27 | c10::Symbol sym, |
28 | std::shared_ptr<torch::jit::GraphFunction> function, |
29 | const std::vector<torch::jit::NamedValue>& arguments, |
30 | const std::vector<torch::jit::NamedValue>& kwarguments = {}) { |
31 | return LowerTSBuiltin(function, sym, arguments, kwarguments); |
32 | } |
33 | |
34 | TSOpVector LowerTSBuiltin( |
35 | std::shared_ptr<torch::jit::GraphFunction> function, |
36 | c10::Symbol sym, |
37 | const std::vector<torch::jit::NamedValue>& arguments, |
38 | const std::vector<torch::jit::NamedValue>& kwarguments) { |
39 | auto builtin = |
40 | std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt); |
41 | auto magic_method = std::make_shared<torch::jit::MagicMethod>("" , builtin); |
42 | auto ret = magic_method->call({}, *function, arguments, kwarguments, 0); |
43 | auto sv = dynamic_cast<torch::jit::SimpleValue*>(ret.get()); |
44 | CHECK(sv); |
45 | if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) { |
46 | const auto tuple_call_result = sv->asTuple({}, *function); |
47 | TSOpVector tuple_result; |
48 | for (const auto& tuple_component : tuple_call_result) { |
49 | auto tuple_component_sv = |
50 | dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get()); |
51 | tuple_result.push_back(tuple_component_sv->getValue()); |
52 | } |
53 | return tuple_result; |
54 | } |
55 | return {sv->getValue()}; |
56 | } |
57 | |
58 | torch::jit::Value* GenerateClone( |
59 | torch::jit::Value* val, |
60 | std::shared_ptr<torch::jit::GraphFunction> function) { |
61 | std::vector<torch::jit::NamedValue> clone_arguments; |
62 | clone_arguments.emplace_back(val); |
63 | TSOpVector cloned = LowerBuiltin(at::aten::clone, function, clone_arguments); |
64 | TORCH_CHECK_EQ(cloned.size(), 1); |
65 | return cloned.front(); |
66 | } |
67 | |
68 | void GenerateCopy( |
69 | torch::jit::Value* destination, |
70 | torch::jit::Value* source, |
71 | std::shared_ptr<torch::jit::GraphFunction> function) { |
72 | std::vector<torch::jit::NamedValue> arguments; |
73 | arguments.emplace_back(destination); |
74 | arguments.emplace_back(source); |
75 | LowerBuiltin(at::aten::copy_, function, arguments); |
76 | } |
77 | |
78 | torch::jit::Value* GenerateSlice( |
79 | torch::jit::Value* base, |
80 | int64_t dim, |
81 | int64_t start, |
82 | int64_t end, |
83 | int64_t step, |
84 | std::shared_ptr<torch::jit::GraphFunction> function) { |
85 | std::vector<torch::jit::NamedValue> arguments; |
86 | arguments.emplace_back(base); |
87 | arguments.emplace_back(dim); |
88 | arguments.emplace_back(start); |
89 | arguments.emplace_back(end); |
90 | arguments.emplace_back(step); |
91 | TSOpVector selected = LowerBuiltin(at::aten::slice, function, arguments); |
92 | TORCH_CHECK_EQ(selected.size(), 1); |
93 | return selected.front(); |
94 | } |
95 | |
96 | // Node Lowerings |
97 | |
98 | // Default node lowering |
99 | TSOpVector TsNode::Lower( |
100 | std::shared_ptr<torch::jit::GraphFunction> function, |
101 | TSLoweringContext* loctx) const { |
102 | std::vector<torch::jit::NamedValue> arguments; |
103 | for (const torch::lazy::Output& output : operands()) { |
104 | arguments.emplace_back(loctx->GetOutputOp(output)); |
105 | } |
106 | return LowerBuiltin(this, function, arguments); |
107 | } |
108 | |
109 | // Non-native ops |
110 | torch::lazy::TSOpVector Cast::Lower( |
111 | std::shared_ptr<torch::jit::GraphFunction> function, |
112 | torch::lazy::TSLoweringContext* loctx) const { |
113 | std::vector<torch::jit::NamedValue> arguments; |
114 | arguments.emplace_back(loctx->GetOutputOp(operand(0))); |
115 | arguments.emplace_back(dtype); |
116 | return LowerBuiltin(at::aten::to, function, arguments); |
117 | } |
118 | |
119 | torch::lazy::TSOpVector DeviceData::Lower( |
120 | std::shared_ptr<torch::jit::GraphFunction> function, |
121 | torch::lazy::TSLoweringContext* loctx) const { |
122 | auto infoptr = data_->info(); |
123 | auto deviceDataInfoPtr = |
124 | (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr; |
125 | if (GRAPH_DUMP_ENABLED) { |
126 | LOG(ERROR) << "Lowering device data node, tensor id " |
127 | << deviceDataInfoPtr->tensor_id << std::endl; |
128 | } |
129 | return {loctx->GetParameter(data_)}; |
130 | } |
131 | |
132 | torch::lazy::TSOpVector Expand::Lower( |
133 | std::shared_ptr<torch::jit::GraphFunction> function, |
134 | torch::lazy::TSLoweringContext* loctx) const { |
135 | std::vector<torch::jit::NamedValue> arguments; |
136 | arguments.emplace_back(loctx->GetOutputOp(operand(0))); |
137 | arguments.emplace_back(size); |
138 | auto expand_out = LowerBuiltin(this, function, arguments); |
139 | if (is_scalar_expand) { |
140 | // The aten::expand operations sets all strides to 0 when the original is |
141 | // of rank 0. This leads to false positives when checking for internal |
142 | // memory overlap, because at::has_internal_overlap returns |
143 | // MemOverlap::YES when a stride is set to 0. |
144 | TORCH_CHECK_EQ(expand_out.size(), 1); |
145 | return {GenerateClone(expand_out.front(), function)}; |
146 | } |
147 | return expand_out; |
148 | } |
149 | |
150 | torch::lazy::TSOpVector Scalar::Lower( |
151 | std::shared_ptr<torch::jit::GraphFunction> function, |
152 | torch::lazy::TSLoweringContext* loctx) const { |
153 | auto options = |
154 | at::TensorOptions() |
155 | .device(torch::lazy::getBackend()->EagerFallbackDeviceType()) |
156 | .dtype(shape().scalar_type()); |
157 | return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))}; |
158 | } |
159 | |
160 | } // namespace lazy |
161 | } // namespace torch |
162 | |