#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/ts_backend/ts_backend_impl.h>
#include <torch/csrc/lazy/ts_backend/ts_lowering_context.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>

#include <string>
#include <utility>
5 | |
6 | namespace torch { |
7 | namespace lazy { |
8 | |
9 | TSLoweringContext::TSLoweringContext( |
10 | const std::string& name, |
11 | BackendDevice device) |
12 | : torch::lazy::LoweringContext(name, device), |
13 | graph_(std::make_shared<torch::jit::Graph>()), |
14 | function_( |
15 | std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)) {} |
16 | |
17 | TSLoweringContext::TSLoweringContext( |
18 | const std::string& name, |
19 | BackendDevice device, |
20 | c10::ArrayRef<const Node*> post_order, |
21 | Util::EmissionMap emit_status) |
22 | : torch::lazy::LoweringContext(name, device, post_order, emit_status), |
23 | graph_(std::make_shared<torch::jit::Graph>()), |
24 | function_( |
25 | std::make_shared<torch::jit::GraphFunction>(name, graph_, nullptr)) { |
26 | for (auto node : post_order) { |
27 | Lower(node); |
28 | } |
29 | } |
30 | |
31 | void TSLoweringContext::Lower(const Node* node) { |
32 | if (auto* tsnode = dynamic_cast<const torch::lazy::TsNode*>(node)) { |
33 | // First, we call the node lowering function, which exists for newly |
34 | // codegenned or refactored nodes |
35 | TSOpVector ops = tsnode->Lower(function_, this); |
36 | CHECK(!ops.empty()) << "Failed to lower: " << *node; |
37 | TORCH_CHECK_EQ(node->num_outputs(), ops.size()); |
38 | for (size_t i = 0; i < ops.size(); ++i) { |
39 | AssignOutputOp(torch::lazy::Output(node, i), ops[i]); |
40 | } |
41 | } else { |
42 | throw std::runtime_error( |
43 | "Expected torch::lazy::TsNode but could not dynamic cast" ); |
44 | } |
45 | } |
46 | |
47 | void TSLoweringContext::AssignOutputOp( |
48 | const Output& output, |
49 | torch::jit::Value* op) { |
50 | const TsNode* ts_node = static_cast<const TsNode*>(output.node); |
51 | std::string stack_trace = ts_node->getPythonStacktrace(); |
52 | if (!stack_trace.empty()) { |
53 | op->node()->s_(c10::Symbol::attr("source" ), stack_trace); |
54 | } |
55 | emitted_outputs_[output] = op; |
56 | } |
57 | |
58 | torch::jit::Value* TSLoweringContext::GetParameter(BackendDataPtr data) { |
59 | const auto ts_data = std::static_pointer_cast<TSData>(data); |
60 | BackendData::Handle handle = ts_data->GetHandle(); |
61 | auto it = parameters_map_.find(handle); |
62 | if (it == parameters_map_.end()) { |
63 | torch::jit::Value* param = |
64 | graph_->addInput(c10::str("p" , parameters_.size())); |
65 | if (ts_data->scalar.has_value()) { |
66 | auto scalarType = ts_data->scalar.value().type(); |
67 | if (isFloatingType(scalarType)) { |
68 | param->setType(c10::FloatType::get()); |
69 | } else if (isIntegralType(scalarType, /*includeBool=*/true)) { |
70 | param->setType(c10::IntType::get()); |
71 | } else { |
72 | TORCH_CHECK( |
73 | false, "Unhandled scalar type: " , c10::toString(scalarType)); |
74 | } |
75 | } |
76 | it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()}) |
77 | .first; |
78 | parameters_.push_back(ts_data); |
79 | } |
80 | parameter_sequence_.push_back(it->second.index); |
81 | return it->second.param; |
82 | } |
83 | |
84 | } // namespace lazy |
85 | } // namespace torch |
86 | |