1 | #include <torch/csrc/lazy/backend/backend_interface.h> |
2 | #include <torch/csrc/lazy/core/cache.h> |
3 | #include <torch/csrc/lazy/core/config.h> |
4 | #include <torch/csrc/lazy/core/ir.h> |
5 | #include <torch/csrc/lazy/core/ir_metadata.h> |
6 | |
// Enables caching for dynamic shapes, i.e. disables hashing on shapes.
// Read by Node::enableDynamicShape(), which also reports dynamic shapes as
// enabled when the LTC_ENABLE_DYNAMIC_SHAPES environment variable is set.
// Defaults to false.
C10_DEFINE_bool(
    ltc_enable_dynamic_shapes,
    false,
    "Whether dynamic shape is enabled");
12 | |
13 | namespace torch { |
14 | namespace lazy { |
15 | |
16 | static const torch::lazy::Output kNullOutput = torch::lazy::Output(); |
17 | |
18 | size_t Output::Hasher::operator()(const Output& output) const { |
19 | return StdHashCombine( |
20 | reinterpret_cast<std::ptrdiff_t>(output.node), output.index); |
21 | } |
22 | |
23 | hash_t Output::hash() const { |
24 | return HashCombine(node->hash(), Hash(index)); |
25 | } |
26 | |
27 | hash_t Output::shapeHash() const { |
28 | return HashCombine(node->shapeHash(), Hash(index)); |
29 | } |
30 | |
31 | std::string Output::ToString() const { |
32 | std::stringstream ss; |
33 | ss << node->ToString() << ", index=" << index; |
34 | return ss.str(); |
35 | } |
36 | |
37 | bool Output::operator==(const Value& rhs) const { |
38 | // Either side could be kNullValue which has node as nullptr |
39 | return (!node == !rhs.node) && |
40 | (!node || (node->hash() == rhs.node->hash() && index == rhs.index)); |
41 | } |
42 | |
43 | hash_t Value::hash() const { |
44 | return HashCombine(node->hash(), Hash(index)); |
45 | } |
46 | |
47 | hash_t Value::shapeHash() const { |
48 | return HashCombine(node->shapeHash(), Hash(index)); |
49 | } |
50 | |
51 | OpKind OpKind::Get(const std::string& name) { |
52 | return OpKind(c10::Symbol::fromQualString(name)); |
53 | } |
54 | |
55 | hash_t OpKind::hash() const { |
56 | return StringHash(op.toQualString()); |
57 | } |
58 | |
59 | bool Node::enableDynamicShape() { |
60 | static bool enabled = std::getenv("LTC_ENABLE_DYNAMIC_SHAPES" ) != nullptr; |
61 | return enabled || FLAGS_ltc_enable_dynamic_shapes; |
62 | } |
63 | |
64 | Node::Node(OpKind op, size_t num_outputs) |
65 | : op_(op), num_outputs_(num_outputs), metadata_(GetMetaDataIfDebugging()) {} |
66 | |
67 | Node::Node( |
68 | OpKind op, |
69 | OpList operands, |
70 | std::vector<Shape>&& shapes, |
71 | size_t num_outputs) |
72 | : Node(op, num_outputs) { |
73 | // Move shapes into node |
74 | shapes_.insert( |
75 | shapes_.end(), |
76 | std::make_move_iterator(shapes.begin()), |
77 | std::make_move_iterator(shapes.end())); |
78 | |
79 | for (auto& operand : operands) { |
80 | // Ideally, optional operands should be filtered by the leaf node classes, |
81 | // but it's just much easier to do it here. |
82 | // TODO(alanwaketan): Find a way to move the below logic to the leaf node |
83 | // classes. |
84 | if (!operand) { |
85 | continue; |
86 | } |
87 | |
88 | AddOperand(operand.node, operand.index); |
89 | } |
90 | } |
91 | |
92 | Node::Node( |
93 | OpKind op, |
94 | OpList operands, |
95 | const std::function<Shape()>& shape_fn, |
96 | size_t num_outputs) |
97 | : Node(op, operands, std::vector<Shape>{}, num_outputs) { |
98 | addComputedShape(shape_fn); |
99 | } |
100 | |
101 | Node::Node(OpKind op, OpList operands, size_t num_outputs) |
102 | : Node(op, operands, std::vector<Shape>{}, num_outputs) {} |
103 | |
104 | Node::Node(OpKind op, Shape shape, size_t num_outputs) : Node(op, num_outputs) { |
105 | shapes_.push_back(std::move(shape)); |
106 | } |
107 | |
// Destructor defined out of line in this file and defaulted: members
// (shapes_, operands_, operands_as_outputs_, ...) are cleaned up by the
// compiler-generated logic.
Node::~Node() = default;
109 | |
110 | // Retrieves the full shape of the IR Node. |
111 | c10::ArrayRef<Shape> Node::shapes() const { |
112 | return shapes_; |
113 | } |
114 | |
115 | // Retrieves the shape of the output at a given index. |
116 | const Shape& Node::shape(size_t output_index) const { |
117 | return shapes_.at(output_index); |
118 | } |
119 | |
120 | // Add the shape computed by the shape_fn |
121 | |
122 | void Node::addComputedShape(const std::function<Shape()>& shape_fn) { |
123 | shapes_.push_back(computeShape(shape_fn)); |
124 | } |
125 | |
126 | using ShapeCache = Cache<hash_t, Shape, HashReducer>; |
127 | |
128 | // Compute the shape using the provided shape_fn. |
129 | Shape Node::computeShape(const std::function<Shape()>& shape_fn) { |
130 | static ShapeCache* cache = new ShapeCache(FLAGS_torch_lazy_shape_cache_size); |
131 | |
132 | auto hash = shapeHash(); |
133 | auto shape = cache->Get(hash); |
134 | if (shape == nullptr) { |
135 | shape = cache->Add(hash, std::make_shared<Shape>(shape_fn())); |
136 | } |
137 | return *shape; |
138 | } |
139 | |
140 | const std::vector<Output>& Node::operands() const { |
141 | return operands_as_outputs_; |
142 | } |
143 | |
144 | const Output& Node::operand(size_t i) const { |
145 | return operands_as_outputs_.at(i); |
146 | } |
147 | |
148 | const Output& Node::nullable_operand(size_t i) const { |
149 | // We use kNullOutput instead of kNullValue here to avoid implicit casting, |
150 | // which would prevent this method from returning a reference. |
151 | return i < operands_as_outputs_.size() ? operand(i) : kNullOutput; |
152 | } |
153 | |
154 | std::string Node::ToString() const { |
155 | std::stringstream ss; |
156 | ss << shapes() << " " << op(); |
157 | if (num_outputs() > 1) { |
158 | ss << ", num_outputs=" << num_outputs(); |
159 | } |
160 | if (!metadata().scope.empty()) { |
161 | ss << ", scope=" << metadata().scope; |
162 | } |
163 | EmitShortFrameInfo(ss, metadata().frame_info); |
164 | return ss.str(); |
165 | } |
166 | |
167 | void Node::AddOperand(NodePtr node, size_t index) { |
168 | TORCH_CHECK_LT(index, node->num_outputs()); |
169 | operands_.push_back(node); |
170 | operands_as_outputs_.emplace_back(operands_.back().get(), index); |
171 | } |
172 | |
173 | } // namespace lazy |
174 | } // namespace torch |
175 | |