#include <torch/csrc/lazy/ts_backend/ts_node_lowering.h>

#include <ATen/Functions.h>
#include <torch/csrc/jit/frontend/sugared_value.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/helpers.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/ops/utils.h>
#include <torch/csrc/lazy/core/permutation_util.h>
#include <torch/csrc/lazy/ts_backend/ir_builder.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>

namespace torch {
namespace lazy {

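// Convenience overloads: lower a lazy node (or a bare op symbol) by
// resolving it as a TorchScript builtin via LowerTSBuiltin below.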
TSOpVector LowerBuiltin(
    const torch::lazy::Node* node,
    std::shared_ptr<torch::jit::GraphFunction> function,
    const std::vector<torch::jit::NamedValue>& arguments,
    const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
  return LowerTSBuiltin(function, node->op().op, arguments, kwarguments);
}

TSOpVector LowerBuiltin(
    c10::Symbol sym,
    std::shared_ptr<torch::jit::GraphFunction> function,
    const std::vector<torch::jit::NamedValue>& arguments,
    const std::vector<torch::jit::NamedValue>& kwarguments = {}) {
  return LowerTSBuiltin(function, sym, arguments, kwarguments);
}

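// Emits a call to the TorchScript builtin `sym` into `function`'s graph and
// returns the resulting jit::Values. A tuple-typed result is unpacked into
// its components, so callers always receive a flat list of outputs.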
TSOpVector LowerTSBuiltin(
    std::shared_ptr<torch::jit::GraphFunction> function,
    c10::Symbol sym,
    const std::vector<torch::jit::NamedValue>& arguments,
    const std::vector<torch::jit::NamedValue>& kwarguments) {
  auto builtin =
      std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
  auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
  auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
  auto sv = dynamic_cast<torch::jit::SimpleValue*>(ret.get());
  TORCH_CHECK(sv);
  if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
    const auto tuple_call_result = sv->asTuple({}, *function);
    TSOpVector tuple_result;
    for (const auto& tuple_component : tuple_call_result) {
      auto tuple_component_sv =
          dynamic_cast<torch::jit::SimpleValue*>(tuple_component.get());
      TORCH_CHECK(tuple_component_sv);
      tuple_result.push_back(tuple_component_sv->getValue());
    }
    return tuple_result;
  }
  return {sv->getValue()};
}

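// Clones `val` via aten::clone and returns the cloned value.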
torch::jit::Value* GenerateClone(
    torch::jit::Value* val,
    std::shared_ptr<torch::jit::GraphFunction> function) {
  std::vector<torch::jit::NamedValue> clone_arguments;
  clone_arguments.emplace_back(val);
  TSOpVector cloned = LowerBuiltin(at::aten::clone, function, clone_arguments);
  TORCH_CHECK_EQ(cloned.size(), 1);
  return cloned.front();
}

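// Copies `source` into `destination` in place via aten::copy_.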
void GenerateCopy(
    torch::jit::Value* destination,
    torch::jit::Value* source,
    std::shared_ptr<torch::jit::GraphFunction> function) {
  std::vector<torch::jit::NamedValue> arguments;
  arguments.emplace_back(destination);
  arguments.emplace_back(source);
  LowerBuiltin(at::aten::copy_, function, arguments);
}

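// Slices `base` along `dim` over [start, end) with the given step via
// aten::slice and returns the resulting value.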
torch::jit::Value* GenerateSlice(
    torch::jit::Value* base,
    int64_t dim,
    int64_t start,
    int64_t end,
    int64_t step,
    std::shared_ptr<torch::jit::GraphFunction> function) {
  std::vector<torch::jit::NamedValue> arguments;
  arguments.emplace_back(base);
  arguments.emplace_back(dim);
  arguments.emplace_back(start);
  arguments.emplace_back(end);
  arguments.emplace_back(step);
  TSOpVector selected = LowerBuiltin(at::aten::slice, function, arguments);
  TORCH_CHECK_EQ(selected.size(), 1);
  return selected.front();
}

// Node Lowerings

// Default node lowering
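// Passes each operand's already-lowered jit::Value positionally and lowers
// the node through the TorchScript builtin matching its op symbol. For
// example, a lazy aten::add node with operands %a and %b produces a graph
// call along the lines of `%out = aten::add(%a, %b)` (illustrative; the
// exact node emitted depends on the builtin's schema).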
TSOpVector TsNode::Lower(
    std::shared_ptr<torch::jit::GraphFunction> function,
    TSLoweringContext* loctx) const {
  std::vector<torch::jit::NamedValue> arguments;
  for (const torch::lazy::Output& output : operands()) {
    arguments.emplace_back(loctx->GetOutputOp(output));
  }
  return LowerBuiltin(this, function, arguments);
}

// Non-native ops
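// Cast lowers to aten::to(input, dtype).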
torch::lazy::TSOpVector Cast::Lower(
    std::shared_ptr<torch::jit::GraphFunction> function,
    torch::lazy::TSLoweringContext* loctx) const {
  std::vector<torch::jit::NamedValue> arguments;
  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
  arguments.emplace_back(dtype);
  return LowerBuiltin(at::aten::to, function, arguments);
}

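// DeviceData nodes carry no computation: each one becomes a graph parameter
// that is bound to its backing tensor at execution time.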
torch::lazy::TSOpVector DeviceData::Lower(
    std::shared_ptr<torch::jit::GraphFunction> function,
    torch::lazy::TSLoweringContext* loctx) const {
  auto infoptr = data_->info();
  auto deviceDataInfoPtr =
      static_cast<torch::lazy::LazyGraphExecutor::DeviceDataInfo*>(infoptr);
  if (GRAPH_DUMP_ENABLED) {
    LOG(ERROR) << "Lowering device data node, tensor id "
               << deviceDataInfoPtr->tensor_id << std::endl;
  }
  return {loctx->GetParameter(data_)};
}

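// Expand lowers to aten::expand; expands of rank-0 tensors additionally
// clone the result to materialize well-defined strides (see comment below).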
torch::lazy::TSOpVector Expand::Lower(
    std::shared_ptr<torch::jit::GraphFunction> function,
    torch::lazy::TSLoweringContext* loctx) const {
  std::vector<torch::jit::NamedValue> arguments;
  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
  arguments.emplace_back(size);
  auto expand_out = LowerBuiltin(this, function, arguments);
  if (is_scalar_expand) {
    // The aten::expand operation sets all strides to 0 when the original
    // tensor is of rank 0. This leads to false positives when checking for
    // internal memory overlap, because at::has_internal_overlap returns
    // MemOverlap::YES when a stride is set to 0.
    TORCH_CHECK_EQ(expand_out.size(), 1);
    return {GenerateClone(expand_out.front(), function)};
  }
  return expand_out;
}

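// Scalar values are materialized eagerly on the fallback device and embedded
// into the graph as constants.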
torch::lazy::TSOpVector Scalar::Lower(
    std::shared_ptr<torch::jit::GraphFunction> function,
    torch::lazy::TSLoweringContext* loctx) const {
  auto options =
      at::TensorOptions()
          .device(torch::lazy::getBackend()->EagerFallbackDeviceType())
          .dtype(shape().scalar_type());
  return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))};
}

} // namespace lazy
} // namespace torch