1#include <torch/csrc/lazy/backend/backend_interface.h>
2#include <torch/csrc/lazy/core/cache.h>
3#include <torch/csrc/lazy/core/config.h>
4#include <torch/csrc/lazy/core/ir.h>
5#include <torch/csrc/lazy/core/ir_metadata.h>
6
// Global switch for lazy-tensor dynamic-shape mode (default: off).
// Node::enableDynamicShape() reports true when this flag is set OR when the
// LTC_ENABLE_DYNAMIC_SHAPES environment variable is present.
// NOTE(review): per the original note, enabling this turns on caching for
// dynamic shapes, presumably by excluding shapes from node hashing; that
// logic lives outside this file, so confirm against the hashing code.
C10_DEFINE_bool(
    ltc_enable_dynamic_shapes,
    false,
    "Whether dynamic shape is enabled");
12
13namespace torch {
14namespace lazy {
15
16static const torch::lazy::Output kNullOutput = torch::lazy::Output();
17
18size_t Output::Hasher::operator()(const Output& output) const {
19 return StdHashCombine(
20 reinterpret_cast<std::ptrdiff_t>(output.node), output.index);
21}
22
23hash_t Output::hash() const {
24 return HashCombine(node->hash(), Hash(index));
25}
26
27hash_t Output::shapeHash() const {
28 return HashCombine(node->shapeHash(), Hash(index));
29}
30
31std::string Output::ToString() const {
32 std::stringstream ss;
33 ss << node->ToString() << ", index=" << index;
34 return ss.str();
35}
36
37bool Output::operator==(const Value& rhs) const {
38 // Either side could be kNullValue which has node as nullptr
39 return (!node == !rhs.node) &&
40 (!node || (node->hash() == rhs.node->hash() && index == rhs.index));
41}
42
43hash_t Value::hash() const {
44 return HashCombine(node->hash(), Hash(index));
45}
46
47hash_t Value::shapeHash() const {
48 return HashCombine(node->shapeHash(), Hash(index));
49}
50
51OpKind OpKind::Get(const std::string& name) {
52 return OpKind(c10::Symbol::fromQualString(name));
53}
54
55hash_t OpKind::hash() const {
56 return StringHash(op.toQualString());
57}
58
59bool Node::enableDynamicShape() {
60 static bool enabled = std::getenv("LTC_ENABLE_DYNAMIC_SHAPES") != nullptr;
61 return enabled || FLAGS_ltc_enable_dynamic_shapes;
62}
63
64Node::Node(OpKind op, size_t num_outputs)
65 : op_(op), num_outputs_(num_outputs), metadata_(GetMetaDataIfDebugging()) {}
66
67Node::Node(
68 OpKind op,
69 OpList operands,
70 std::vector<Shape>&& shapes,
71 size_t num_outputs)
72 : Node(op, num_outputs) {
73 // Move shapes into node
74 shapes_.insert(
75 shapes_.end(),
76 std::make_move_iterator(shapes.begin()),
77 std::make_move_iterator(shapes.end()));
78
79 for (auto& operand : operands) {
80 // Ideally, optional operands should be filtered by the leaf node classes,
81 // but it's just much easier to do it here.
82 // TODO(alanwaketan): Find a way to move the below logic to the leaf node
83 // classes.
84 if (!operand) {
85 continue;
86 }
87
88 AddOperand(operand.node, operand.index);
89 }
90}
91
92Node::Node(
93 OpKind op,
94 OpList operands,
95 const std::function<Shape()>& shape_fn,
96 size_t num_outputs)
97 : Node(op, operands, std::vector<Shape>{}, num_outputs) {
98 addComputedShape(shape_fn);
99}
100
101Node::Node(OpKind op, OpList operands, size_t num_outputs)
102 : Node(op, operands, std::vector<Shape>{}, num_outputs) {}
103
104Node::Node(OpKind op, Shape shape, size_t num_outputs) : Node(op, num_outputs) {
105 shapes_.push_back(std::move(shape));
106}
107
// Defaulted destructor, defined out of line. Members (shapes_, operands_,
// operands_as_outputs_, ...) release themselves.
// NOTE(review): presumably defined in this .cpp rather than the header so the
// destructor (and vtable, if Node is polymorphic) is emitted here; confirm.
Node::~Node() = default;
109
110// Retrieves the full shape of the IR Node.
111c10::ArrayRef<Shape> Node::shapes() const {
112 return shapes_;
113}
114
115// Retrieves the shape of the output at a given index.
116const Shape& Node::shape(size_t output_index) const {
117 return shapes_.at(output_index);
118}
119
120// Add the shape computed by the shape_fn
121
122void Node::addComputedShape(const std::function<Shape()>& shape_fn) {
123 shapes_.push_back(computeShape(shape_fn));
124}
125
126using ShapeCache = Cache<hash_t, Shape, HashReducer>;
127
128// Compute the shape using the provided shape_fn.
129Shape Node::computeShape(const std::function<Shape()>& shape_fn) {
130 static ShapeCache* cache = new ShapeCache(FLAGS_torch_lazy_shape_cache_size);
131
132 auto hash = shapeHash();
133 auto shape = cache->Get(hash);
134 if (shape == nullptr) {
135 shape = cache->Add(hash, std::make_shared<Shape>(shape_fn()));
136 }
137 return *shape;
138}
139
140const std::vector<Output>& Node::operands() const {
141 return operands_as_outputs_;
142}
143
144const Output& Node::operand(size_t i) const {
145 return operands_as_outputs_.at(i);
146}
147
148const Output& Node::nullable_operand(size_t i) const {
149 // We use kNullOutput instead of kNullValue here to avoid implicit casting,
150 // which would prevent this method from returning a reference.
151 return i < operands_as_outputs_.size() ? operand(i) : kNullOutput;
152}
153
154std::string Node::ToString() const {
155 std::stringstream ss;
156 ss << shapes() << " " << op();
157 if (num_outputs() > 1) {
158 ss << ", num_outputs=" << num_outputs();
159 }
160 if (!metadata().scope.empty()) {
161 ss << ", scope=" << metadata().scope;
162 }
163 EmitShortFrameInfo(ss, metadata().frame_info);
164 return ss.str();
165}
166
167void Node::AddOperand(NodePtr node, size_t index) {
168 TORCH_CHECK_LT(index, node->num_outputs());
169 operands_.push_back(node);
170 operands_as_outputs_.emplace_back(operands_.back().get(), index);
171}
172
173} // namespace lazy
174} // namespace torch
175