1 | #include <torch/csrc/lazy/core/debug_util.h> |
2 | #include <torch/csrc/lazy/ts_backend/ts_node.h> |
3 | |
4 | namespace { |
5 | std::string GetFirstUserFrameInPythonIfEnabled() { |
6 | static const auto LTC_ENABLE_SOURCE_INFO = |
7 | std::getenv("LTC_ENABLE_SOURCE_INFO" ); |
8 | if (!LTC_ENABLE_SOURCE_INFO) { |
9 | return {}; |
10 | } |
11 | |
12 | return torch::lazy::GetFirstUserFrameInPython(); |
13 | } |
14 | } // namespace |
15 | |
16 | namespace torch { |
17 | namespace lazy { |
18 | |
19 | hash_t OperandHashes( |
20 | const OpList& operands, |
21 | const c10::ArrayRef<Shape>& shapes, |
22 | const hash_t& seed, |
23 | bool bakeInSizes) { |
24 | hash_t hash = seed; |
25 | for (auto& operand : operands) { |
26 | if (!operand) { |
27 | hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt)); |
28 | continue; |
29 | } |
30 | auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash(); |
31 | hash = HashCombine(hash, operand_hash); |
32 | } |
33 | for (auto& shape : shapes) { |
34 | hash = HashCombine(hash, shape.hash(bakeInSizes)); |
35 | } |
36 | return hash; |
37 | } |
38 | |
39 | TsNode::TsNode( |
40 | OpKind op, |
41 | OpList operands, |
42 | std::vector<Shape>&& shapes, |
43 | size_t num_outputs, |
44 | hash_t hash_seed) |
45 | : Node(op, operands, std::move(shapes), num_outputs) { |
46 | hash_seed = HashCombine(op.hash(), hash_seed); |
47 | shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true); |
48 | dag_hash_ = |
49 | (enableDynamicShape() |
50 | ? OperandHashes(operands, this->shapes(), hash_seed, false) |
51 | : shape_hash_); |
52 | } |
53 | |
54 | TsNode::TsNode( |
55 | OpKind op, |
56 | OpList operands, |
57 | const std::function<Shape()>& shape_fn, |
58 | size_t num_outputs, |
59 | hash_t hash_seed) |
60 | : TsNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) { |
61 | addComputedShape(shape_fn); |
62 | } |
63 | |
64 | TsNode::TsNode(OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed) |
65 | : TsNode(op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {} |
66 | |
67 | TsNode::TsNode(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed) |
68 | : TsNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {} |
69 | |
70 | hash_t TsNode::hash() const { |
71 | return dag_hash_; |
72 | } |
73 | |
74 | hash_t TsNode::shapeHash() const { |
75 | return shape_hash_; |
76 | } |
77 | |
78 | const std::string TsNode::getPythonStacktrace() const { |
79 | return GetFirstUserFrameInPythonIfEnabled(); |
80 | } |
81 | |
82 | TensorList::TensorList(OpList values) |
83 | : TsNode( |
84 | /*op=*/ClassOpKind(), |
85 | /*operands=*/values, |
86 | /*shapes=*/std::vector<Shape>(), |
87 | /*num_outputs=*/1, |
88 | /*hash_seed=*/kHashSeed) {} |
89 | |
90 | TSOpVector TensorList::Lower( |
91 | std::shared_ptr<torch::jit::GraphFunction> function, |
92 | TSLoweringContext* loctx) const { |
93 | std::vector<torch::jit::Value*> tensor_list; |
94 | CHECK(!operands().empty()); |
95 | for (const torch::lazy::Output& operand : operands()) { |
96 | tensor_list.emplace_back(loctx->GetOutputOp(operand)); |
97 | } |
98 | auto graph = function->graph(); |
99 | auto listnode = |
100 | graph->insertNode(graph->createList(tensor_list[0]->type(), tensor_list)); |
101 | return {listnode->output()}; |
102 | } |
103 | |
104 | } // namespace lazy |
105 | } // namespace torch |
106 | |