1#include <torch/csrc/lazy/core/debug_util.h>
2#include <torch/csrc/lazy/ts_backend/ts_node.h>
3
4namespace {
5std::string GetFirstUserFrameInPythonIfEnabled() {
6 static const auto LTC_ENABLE_SOURCE_INFO =
7 std::getenv("LTC_ENABLE_SOURCE_INFO");
8 if (!LTC_ENABLE_SOURCE_INFO) {
9 return {};
10 }
11
12 return torch::lazy::GetFirstUserFrameInPython();
13}
14} // namespace
15
16namespace torch {
17namespace lazy {
18
19hash_t OperandHashes(
20 const OpList& operands,
21 const c10::ArrayRef<Shape>& shapes,
22 const hash_t& seed,
23 bool bakeInSizes) {
24 hash_t hash = seed;
25 for (auto& operand : operands) {
26 if (!operand) {
27 hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
28 continue;
29 }
30 auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash();
31 hash = HashCombine(hash, operand_hash);
32 }
33 for (auto& shape : shapes) {
34 hash = HashCombine(hash, shape.hash(bakeInSizes));
35 }
36 return hash;
37}
38
39TsNode::TsNode(
40 OpKind op,
41 OpList operands,
42 std::vector<Shape>&& shapes,
43 size_t num_outputs,
44 hash_t hash_seed)
45 : Node(op, operands, std::move(shapes), num_outputs) {
46 hash_seed = HashCombine(op.hash(), hash_seed);
47 shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true);
48 dag_hash_ =
49 (enableDynamicShape()
50 ? OperandHashes(operands, this->shapes(), hash_seed, false)
51 : shape_hash_);
52}
53
54TsNode::TsNode(
55 OpKind op,
56 OpList operands,
57 const std::function<Shape()>& shape_fn,
58 size_t num_outputs,
59 hash_t hash_seed)
60 : TsNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {
61 addComputedShape(shape_fn);
62}
63
64TsNode::TsNode(OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
65 : TsNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}
66
67TsNode::TsNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
68 : TsNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {}
69
70hash_t TsNode::hash() const {
71 return dag_hash_;
72}
73
74hash_t TsNode::shapeHash() const {
75 return shape_hash_;
76}
77
78const std::string TsNode::getPythonStacktrace() const {
79 return GetFirstUserFrameInPythonIfEnabled();
80}
81
82TensorList::TensorList(OpList values)
83 : TsNode(
84 /*op=*/ClassOpKind(),
85 /*operands=*/values,
86 /*shapes=*/std::vector<Shape>(),
87 /*num_outputs=*/1,
88 /*hash_seed=*/kHashSeed) {}
89
90TSOpVector TensorList::Lower(
91 std::shared_ptr<torch::jit::GraphFunction> function,
92 TSLoweringContext* loctx) const {
93 std::vector<torch::jit::Value*> tensor_list;
94 CHECK(!operands().empty());
95 for (const torch::lazy::Output& operand : operands()) {
96 tensor_list.emplace_back(loctx->GetOutputOp(operand));
97 }
98 auto graph = function->graph();
99 auto listnode =
100 graph->insertNode(graph->createList(tensor_list[0]->type(), tensor_list));
101 return {listnode->output()};
102}
103
104} // namespace lazy
105} // namespace torch
106