1 | #pragma once |
2 | |
3 | #include <ATen/core/symbol.h> |
4 | |
5 | #include <functional> |
6 | #include <memory> |
7 | #include <set> |
8 | #include <string> |
9 | #include <unordered_map> |
10 | #include <unordered_set> |
11 | #include <utility> |
12 | #include <vector> |
13 | |
14 | #include <c10/core/ScalarType.h> |
15 | #include <c10/util/ArrayRef.h> |
16 | #include <c10/util/Flags.h> |
17 | #include <torch/csrc/lazy/core/hash.h> |
18 | #include <torch/csrc/lazy/core/ir_metadata.h> |
19 | #include <torch/csrc/lazy/core/shape.h> |
20 | |
// Declares the boolean flag `ltc_enable_dynamic_shapes` through the c10 flags
// macro so that code including this header can refer to it.
// NOTE(review): presumably this is what Node::enableDynamicShape() reports --
// confirm against the definition.
C10_DECLARE_bool(ltc_enable_dynamic_shapes);
22 | |
23 | namespace torch { |
24 | namespace lazy { |
25 | |
// Seed value for IR hashing.
// NOTE(review): the literal 0x5a2d296e9 has nine hex digits (36 bits), so the
// static_cast<uint32_t> keeps only its low 32 bits and the effective seed is
// 0xa2d296e9. Changing the literal would change the seed value, so it is left
// as is -- confirm whether the extra digit is intentional.
static const hash_t kHashSeed(static_cast<uint32_t>(0x5a2d296e9));
27 | |
// Forward declarations; Node, Output and Value are all defined later in this
// header.
class Node;
struct Output;
struct Value;

// Shared-ownership handle to a Node.
using NodePtr = std::shared_ptr<Node>;
33 | |
34 | // The Kind of operation a Node can be associated to. |
35 | struct TORCH_API OpKind { |
36 | OpKind() = default; |
37 | explicit OpKind(c10::Symbol op) : op(op) {} |
38 | |
39 | bool operator==(const OpKind& rhs) const { |
40 | return op == rhs.op; |
41 | } |
42 | bool operator!=(const OpKind& rhs) const { |
43 | return !operator==(rhs); |
44 | } |
45 | bool operator<(const OpKind& rhs) const { |
46 | return c10::unique_t(op) < c10::unique_t(rhs.op); |
47 | } |
48 | |
49 | hash_t hash() const; |
50 | |
51 | std::string ToString() const { |
52 | return op.toQualString(); |
53 | } |
54 | |
55 | // Retrieves an existing operation object, or creates a new one. Operations |
56 | // that are specific to lazy tensors, should live within the 'lazy_tensors::' |
57 | // namespace. |
58 | static OpKind Get(const std::string& name); |
59 | |
60 | c10::Symbol op; |
61 | }; |
62 | |
63 | inline std::ostream& operator<<(std::ostream& stream, const OpKind& op) { |
64 | stream << op.ToString(); |
65 | return stream; |
66 | } |
67 | |
// A list of operand Values, expressed as a c10::ArrayRef.
using OpList = c10::ArrayRef<Value>;

// Declared here and defined elsewhere.
// NOTE(review): presumably folds the hashes of `operands` into `seed`, with
// `bakeInSizes` selecting whether shape sizes take part in the hash -- confirm
// against the definition.
hash_t OperandHashes(
    const OpList& operands,
    const hash_t& seed,
    bool bakeInSizes);
// A node in the graph. Nodes for operations which require extra data to be
// stored for lowering should inherit from this class and add an
// operation-specific member there. For example, a constant might create a new
// NodeConstant class (inheriting from Node) with an extra lazy_tensors::Literal
// field, or a tensor value might create a new NodeTensor with a computation
// client data handle in it.
//
// The class is abstract: hash() and shapeHash() are pure virtual and must be
// provided by subclasses.
class TORCH_API Node {
 public:
  // NOTE(review): presumably reports the ltc_enable_dynamic_shapes flag
  // declared at the top of this header -- confirm against the definition.
  static bool enableDynamicShape();

  // Creates a new node with the given op name. The op is a unique identifier
  // for the operation. The num_outputs tells how many outputs a given operation
  // generates.
  //
  // A non-leaf node's node_hash does not always contain shape information, so
  // we pass in the hash value rather than a function.
  // NOTE(review): this constructor takes no hash argument, so the remark above
  // appears to be stale -- confirm and remove if so.
  Node(OpKind op, size_t num_outputs);

  // Construct node with operands and shapes. The shapes vector is taken by
  // rvalue reference.
  Node(
      OpKind op,
      OpList operands,
      std::vector<Shape>&& shapes,
      size_t num_outputs = 1);

  // Construct node with operands and shape generated from a function
  Node(
      OpKind op,
      OpList operands,
      const std::function<Shape()>& shape_fn,
      size_t num_outputs = 1);

  // Construct node with operands and no shape
  Node(OpKind op, OpList operands, size_t num_outputs = 1);

  // Construct node with shape and no operands
  Node(OpKind op, Shape shape, size_t num_outputs = 1);

  virtual ~Node();

  // The kind of operation captured by this node.
  const OpKind& op() const {
    return op_;
  }

  // The number of outputs this node's operation generates.
  size_t num_outputs() const {
    return num_outputs_;
  }

  // Retrieves the full shape of the IR Node.
  virtual c10::ArrayRef<Shape> shapes() const;

  // Retrieves the shape of the output at `output_index` (the first output by
  // default).
  virtual const Shape& shape(size_t output_index = 0) const;

  // Add the shape computed by the shape_fn
  void addComputedShape(const std::function<Shape()>& shape_fn);

  // Compute the shape using the provided shape_fn if not previously cached
  Shape computeShape(const std::function<Shape()>& shape_fn);

  // The operands of this node, as Outputs.
  virtual const std::vector<Output>& operands() const;

  // The operand at index i.
  virtual const Output& operand(size_t i) const;

  // Gets operand at index i if index is valid, or kNullOutput otherwise.
  virtual const Output& nullable_operand(size_t i) const;

  // Returns the hash of the dag used to look up the compiled graph
  virtual hash_t hash() const = 0;

  // Returns the hash of the dag used for shape caching
  virtual hash_t shapeHash() const = 0;

  // The IR specific metadata attached to this node.
  const MetaData& metadata() const {
    return metadata_;
  }

  // Non-owning pointer to the user metadata currently attached to this node
  // (the raw pointer held by the internal shared_ptr).
  UserMetaData* user_metadata() const {
    return user_metadata_.get();
  }

  // Attaches `user_meta` to this node and returns the user metadata that was
  // attached before the call.
  std::shared_ptr<UserMetaData> SetUserMetadata(
      std::shared_ptr<UserMetaData> user_meta) {
    // After the swap, `user_meta` holds the previously attached object.
    std::swap(user_metadata_, user_meta);
    return user_meta;
  }

  virtual std::string ToString() const;

 private:
  // The ID of the operation captured by this node.
  OpKind op_;
  // How many outputs the operation generates.
  size_t num_outputs_ = 1;

  // The IR specific metadata attached to the IR node.
  MetaData metadata_;
  // The IR framework user can attach a user defined metadata object deriving
  // from UserMetaData.
  std::shared_ptr<UserMetaData> user_metadata_;

 protected:
  // Adds output number `index` of `node` as an operand of this node.
  void AddOperand(NodePtr node, size_t index = 0);

  // The shapes stored for this node.
  std::vector<Shape> shapes_;
  // A node holds a real reference to its operands.
  std::vector<NodePtr> operands_;
  // Outputs do not hold references on the nodes, and neither do the uses, since
  // otherwise we get into circular reference counting.
  std::vector<Output> operands_as_outputs_;
};
184 | |
185 | inline std::ostream& operator<<(std::ostream& stream, const Node& node) { |
186 | stream << node.ToString(); |
187 | return stream; |
188 | } |
189 | |
190 | // Note: Keep this version of NodeCast for smooth PyTorch/XLA migration, and |
191 | // clean up once the migration is done. |
192 | template <typename T> |
193 | const T* NodeCast(const Node* node, OpKind op) { |
194 | if (op != node->op()) { |
195 | return nullptr; |
196 | } |
197 | #ifdef NDEBUG |
198 | return static_cast<const T*>(node); |
199 | #else |
200 | return &dynamic_cast<const T&>(*node); |
201 | #endif |
202 | } |
203 | |
204 | template <typename T> |
205 | const T* NodeCast(const Node* node) { |
206 | if (T::ClassOpKind() != node->op()) { |
207 | return nullptr; |
208 | } |
209 | // TODO: Some IR classes share the same opkind, such as Mean and MeanDim, so |
210 | // static_cast is not safe here. Unless we have opkind unique for each class, |
211 | // we have to use dynamic_cast here. |
212 | return dynamic_cast<const T*>(node); |
213 | } |
214 | |
215 | // Represents a specific output produced by a node. Since the output of a node |
216 | // can be composed by multiple outputs, the node+index coordinates fully qualify |
217 | // each single output. |
218 | struct TORCH_API Output { |
219 | struct Hasher { |
220 | size_t operator()(const Output& output) const; |
221 | }; |
222 | |
223 | Output() = default; |
224 | explicit Output(const Node* node, size_t index = 0) |
225 | : node(node), index(index) {} |
226 | |
227 | hash_t hash() const; |
228 | hash_t shapeHash() const; |
229 | |
230 | bool operator==(const Output& rhs) const { |
231 | return node == rhs.node && index == rhs.index; |
232 | } |
233 | |
234 | // To compare the operands of to-be-constructed node and to-be-reused node |
235 | bool operator==(const Value& rhs) const; |
236 | |
237 | bool operator!=(const Output& rhs) const { |
238 | return !operator==(rhs); |
239 | } |
240 | |
241 | const Shape& shape() const { |
242 | return node->shape(index); |
243 | } |
244 | |
245 | std::string ToString() const; |
246 | |
247 | // The node providing the output. |
248 | const Node* node{nullptr}; |
249 | // The index in the node's output this output refers to. |
250 | size_t index{0}; |
251 | }; |
252 | |
253 | inline std::ostream& operator<<(std::ostream& stream, const Output& output) { |
254 | stream << output.ToString(); |
255 | return stream; |
256 | } |
257 | |
// Unordered map keyed by Output, hashed with Output::Hasher.
template <typename T>
using OutputMap = std::unordered_map<Output, T, Output::Hasher>;
260 | |
261 | // Represents an input/operand for a Node object. |
262 | struct TORCH_API Value { |
263 | Value() = default; |
264 | /* implicit */ Value(NodePtr&& node, size_t index = 0) |
265 | : node(std::move(node)), index(index) {} |
266 | /* implicit */ Value(const NodePtr& node, size_t index = 0) |
267 | : node(node), index(index) {} |
268 | |
269 | hash_t hash() const; |
270 | hash_t shapeHash() const; |
271 | |
272 | operator bool() const { |
273 | return node != nullptr; |
274 | } |
275 | |
276 | operator Output() const { |
277 | return Output(node.get(), index); |
278 | } |
279 | |
280 | const Shape& shape() const { |
281 | return node->shape(index); |
282 | } |
283 | |
284 | Node* operator->() const { |
285 | return node.get(); |
286 | } |
287 | |
288 | NodePtr node; |
289 | size_t index = 0; |
290 | }; |
291 | |
292 | } // namespace lazy |
293 | } // namespace torch |
294 | |
namespace c10 {
// Explicit template instantiation to make ArrayRef<Value> work.
// torch::lazy::Value is the element type of torch::lazy::OpList.
// NOTE(review): the template is spelled at::ArrayRef here while OpList uses
// c10::ArrayRef; presumably at::ArrayRef is an alias of c10::ArrayRef --
// confirm.
template class at::ArrayRef<torch::lazy::Value>;
} // namespace c10
299 | |