#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/ts_backend/ts_backend_impl.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>

#include <utility>
5
6namespace torch {
7namespace lazy {
8
9TSLoweringContext::TSLoweringContext(
10 const std::string& name,
11 BackendDevice device)
12 : torch::lazy::LoweringContext(name, device),
13 graph_(std::make_shared<torch::jit::Graph>()),
14 function_(
15 std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)) {}
16
17TSLoweringContext::TSLoweringContext(
18 const std::string& name,
19 BackendDevice device,
20 c10::ArrayRef<const Node*> post_order,
21 Util::EmissionMap emit_status)
22 : torch::lazy::LoweringContext(name, device, post_order, emit_status),
23 graph_(std::make_shared<torch::jit::Graph>()),
24 function_(
25 std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)) {
26 for (auto node : post_order) {
27 Lower(node);
28 }
29}
30
31void TSLoweringContext::Lower(const Node* node) {
32 if (auto* tsnode = dynamic_cast<const torch::lazy::TsNode*>(node)) {
33 // First, we call the node lowering function, which exists for newly
34 // codegenned or refactored nodes
35 TSOpVector ops = tsnode->Lower(function_, this);
36 CHECK(!ops.empty()) << "Failed to lower: " << *node;
37 TORCH_CHECK_EQ(node->num_outputs(), ops.size());
38 for (size_t i = 0; i < ops.size(); ++i) {
39 AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
40 }
41 } else {
42 throw std::runtime_error(
43 "Expected torch::lazy::TsNode but could not dynamic cast");
44 }
45}
46
47void TSLoweringContext::AssignOutputOp(
48 const Output& output,
49 torch::jit::Value* op) {
50 const TsNode* ts_node = static_cast<const TsNode*>(output.node);
51 std::string stack_trace = ts_node->getPythonStacktrace();
52 if (!stack_trace.empty()) {
53 op->node()->s_(c10::Symbol::attr("source"), stack_trace);
54 }
55 emitted_outputs_[output] = op;
56}
57
58torch::jit::Value* TSLoweringContext::GetParameter(BackendDataPtr data) {
59 const auto ts_data = std::static_pointer_cast<TSData>(data);
60 BackendData::Handle handle = ts_data->GetHandle();
61 auto it = parameters_map_.find(handle);
62 if (it == parameters_map_.end()) {
63 torch::jit::Value* param =
64 graph_->addInput(c10::str("p", parameters_.size()));
65 if (ts_data->scalar.has_value()) {
66 auto scalarType = ts_data->scalar.value().type();
67 if (isFloatingType(scalarType)) {
68 param->setType(c10::FloatType::get());
69 } else if (isIntegralType(scalarType, /*includeBool=*/true)) {
70 param->setType(c10::IntType::get());
71 } else {
72 TORCH_CHECK(
73 false, "Unhandled scalar type: ", c10::toString(scalarType));
74 }
75 }
76 it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
77 .first;
78 parameters_.push_back(ts_data);
79 }
80 parameter_sequence_.push_back(it->second.index);
81 return it->second.param;
82}
83
84} // namespace lazy
85} // namespace torch
86