1#pragma once
2
3#include <ATen/core/symbol.h>
4
5#include <functional>
6#include <memory>
7#include <set>
8#include <string>
9#include <unordered_map>
10#include <unordered_set>
11#include <utility>
12#include <vector>
13
14#include <c10/core/ScalarType.h>
15#include <c10/util/ArrayRef.h>
16#include <c10/util/Flags.h>
17#include <torch/csrc/lazy/core/hash.h>
18#include <torch/csrc/lazy/core/ir_metadata.h>
19#include <torch/csrc/lazy/core/shape.h>
20
// Declares the boolean gflag-style flag `ltc_enable_dynamic_shapes` so code
// including this header can read it. NOTE(review): the matching
// C10_DEFINE_bool is expected in a source file, and the flag is presumably
// what Node::enableDynamicShape() reports — confirm against the .cpp.
C10_DECLARE_bool(ltc_enable_dynamic_shapes);
22
23namespace torch {
24namespace lazy {
25
26static const hash_t kHashSeed(static_cast<uint32_t>(0x5a2d296e9));
27
// Forward declarations for the IR types defined later in this header.
struct Value;
struct Output;
class Node;

// Shared, owning handle to a Node. Value::node and Node's operand list both
// hold NodePtrs, so a node stays alive while anything consumes it.
using NodePtr = std::shared_ptr<Node>;
33
34// The Kind of operation a Node can be associated to.
35struct TORCH_API OpKind {
36 OpKind() = default;
37 explicit OpKind(c10::Symbol op) : op(op) {}
38
39 bool operator==(const OpKind& rhs) const {
40 return op == rhs.op;
41 }
42 bool operator!=(const OpKind& rhs) const {
43 return !operator==(rhs);
44 }
45 bool operator<(const OpKind& rhs) const {
46 return c10::unique_t(op) < c10::unique_t(rhs.op);
47 }
48
49 hash_t hash() const;
50
51 std::string ToString() const {
52 return op.toQualString();
53 }
54
55 // Retrieves an existing operation object, or creates a new one. Operations
56 // that are specific to lazy tensors, should live within the 'lazy_tensors::'
57 // namespace.
58 static OpKind Get(const std::string& name);
59
60 c10::Symbol op;
61};
62
63inline std::ostream& operator<<(std::ostream& stream, const OpKind& op) {
64 stream << op.ToString();
65 return stream;
66}
67
// Non-owning view over a contiguous sequence of Values; this is the operand
// list type taken by the Node constructors.
using OpList = c10::ArrayRef<Value>;

// Produces a single hash_t from `operands`, starting from `seed`.
// NOTE(review): presumably folds each operand's hash into `seed`, with
// `bakeInSizes` selecting whether shape information is included (compare
// Value::hash vs Value::shapeHash) — confirm against the definition.
hash_t OperandHashes(
    const OpList& operands,
    const hash_t& seed,
    bool bakeInSizes);
// A node in the graph. Nodes for operations which require extra data to be
// stored for lowering should inherit from this class and add an
// operation-specific member there. For example, a constant might create a new
// NodeConstant class (inheriting from Node) with an extra
// lazy_tensors::Literal field, or a tensor value might create a new NodeTensor
// with a computation client data handle in it.
//
// Node is abstract: subclasses must implement hash() and shapeHash().
class TORCH_API Node {
 public:
  // NOTE(review): presumably reports the `ltc_enable_dynamic_shapes` flag
  // declared at the top of this header — confirm in the .cpp.
  static bool enableDynamicShape();

  // Creates a new node with the given op. The op is a unique identifier for
  // the operation; num_outputs tells how many outputs the operation
  // generates. No operands or shapes are passed to this constructor.
  Node(OpKind op, size_t num_outputs);

  // Construct node with operands and shapes
  Node(
      OpKind op,
      OpList operands,
      std::vector<Shape>&& shapes,
      size_t num_outputs = 1);

  // Construct node with operands and shape generated from a function
  Node(
      OpKind op,
      OpList operands,
      const std::function<Shape()>& shape_fn,
      size_t num_outputs = 1);

  // Construct node with operands and no shape
  Node(OpKind op, OpList operands, size_t num_outputs = 1);

  // Construct node with shape and no operands
  Node(OpKind op, Shape shape, size_t num_outputs = 1);

  virtual ~Node();

  // The operation kind this node was constructed with.
  const OpKind& op() const {
    return op_;
  }

  // Number of outputs this node produces.
  size_t num_outputs() const {
    return num_outputs_;
  }

  // Retrieves the full shape of the IR Node.
  virtual c10::ArrayRef<Shape> shapes() const;

  // Shape of the output at `output_index` (defaults to index 0).
  virtual const Shape& shape(size_t output_index = 0) const;

  // Add the shape computed by the shape_fn
  void addComputedShape(const std::function<Shape()>& shape_fn);

  // Compute the shape using the provided shape_fn if not previously cached
  Shape computeShape(const std::function<Shape()>& shape_fn);

  // All operands of this node, as non-owning Output views.
  virtual const std::vector<Output>& operands() const;

  // Operand at index i.
  virtual const Output& operand(size_t i) const;

  // Gets operand at index i if index is valid, or kNullOutput otherwise.
  virtual const Output& nullable_operand(size_t i) const;

  // Returns the hash of the dag used to look up the compiled graph
  virtual hash_t hash() const = 0;

  // Returns the hash of the dag used for shape caching
  virtual hash_t shapeHash() const = 0;

  const MetaData& metadata() const {
    return metadata_;
  }

  // Attached user metadata, or nullptr if none is attached. Ownership stays
  // with the node.
  UserMetaData* user_metadata() const {
    return user_metadata_.get();
  }

  // Attaches `user_meta` to this node and returns the previously attached
  // metadata (which may be null). Implemented as a swap, so no copy of the
  // metadata object is made.
  std::shared_ptr<UserMetaData> SetUserMetadata(
      std::shared_ptr<UserMetaData> user_meta) {
    std::swap(user_metadata_, user_meta);
    return user_meta;
  }

  virtual std::string ToString() const;

 private:
  // The ID of the operation captured by this node.
  OpKind op_;
  size_t num_outputs_ = 1;

  // The IR specific metadata attached to the IR node.
  MetaData metadata_;
  // The IR framework user can attach a user defined metadata object deriving
  // from UserMetaData.
  std::shared_ptr<UserMetaData> user_metadata_;

 protected:
  // Adds node's index output number as operand.
  void AddOperand(NodePtr node, size_t index = 0);

  // Shapes of this node's outputs. NOTE(review): presumably one entry per
  // output — confirm against the constructors.
  std::vector<Shape> shapes_;
  // A node holds a real (owning) reference to its operands.
  std::vector<NodePtr> operands_;
  // Outputs do not hold references on the nodes (Output stores a raw
  // `const Node*`), and neither do the uses, since otherwise we would get
  // into circular reference counting.
  std::vector<Output> operands_as_outputs_;
};
184
185inline std::ostream& operator<<(std::ostream& stream, const Node& node) {
186 stream << node.ToString();
187 return stream;
188}
189
190// Note: Keep this version of NodeCast for smooth PyTorch/XLA migration, and
191// clean up once the migration is done.
192template <typename T>
193const T* NodeCast(const Node* node, OpKind op) {
194 if (op != node->op()) {
195 return nullptr;
196 }
197#ifdef NDEBUG
198 return static_cast<const T*>(node);
199#else
200 return &dynamic_cast<const T&>(*node);
201#endif
202}
203
204template <typename T>
205const T* NodeCast(const Node* node) {
206 if (T::ClassOpKind() != node->op()) {
207 return nullptr;
208 }
209 // TODO: Some IR classes share the same opkind, such as Mean and MeanDim, so
210 // static_cast is not safe here. Unless we have opkind unique for each class,
211 // we have to use dynamic_cast here.
212 return dynamic_cast<const T*>(node);
213}
214
215// Represents a specific output produced by a node. Since the output of a node
216// can be composed by multiple outputs, the node+index coordinates fully qualify
217// each single output.
218struct TORCH_API Output {
219 struct Hasher {
220 size_t operator()(const Output& output) const;
221 };
222
223 Output() = default;
224 explicit Output(const Node* node, size_t index = 0)
225 : node(node), index(index) {}
226
227 hash_t hash() const;
228 hash_t shapeHash() const;
229
230 bool operator==(const Output& rhs) const {
231 return node == rhs.node && index == rhs.index;
232 }
233
234 // To compare the operands of to-be-constructed node and to-be-reused node
235 bool operator==(const Value& rhs) const;
236
237 bool operator!=(const Output& rhs) const {
238 return !operator==(rhs);
239 }
240
241 const Shape& shape() const {
242 return node->shape(index);
243 }
244
245 std::string ToString() const;
246
247 // The node providing the output.
248 const Node* node{nullptr};
249 // The index in the node's output this output refers to.
250 size_t index{0};
251};
252
253inline std::ostream& operator<<(std::ostream& stream, const Output& output) {
254 stream << output.ToString();
255 return stream;
256}
257
258template <typename T>
259using OutputMap = std::unordered_map<Output, T, Output::Hasher>;
260
261// Represents an input/operand for a Node object.
262struct TORCH_API Value {
263 Value() = default;
264 /* implicit */ Value(NodePtr&& node, size_t index = 0)
265 : node(std::move(node)), index(index) {}
266 /* implicit */ Value(const NodePtr& node, size_t index = 0)
267 : node(node), index(index) {}
268
269 hash_t hash() const;
270 hash_t shapeHash() const;
271
272 operator bool() const {
273 return node != nullptr;
274 }
275
276 operator Output() const {
277 return Output(node.get(), index);
278 }
279
280 const Shape& shape() const {
281 return node->shape(index);
282 }
283
284 Node* operator->() const {
285 return node.get();
286 }
287
288 NodePtr node;
289 size_t index = 0;
290};
291
292} // namespace lazy
293} // namespace torch
294
namespace c10 {
// Explicit template instantiation to make ArrayRef<Value> work.
// NOTE(review): this is an explicit instantiation *definition* placed in a
// header, so every translation unit including this file contains one. The
// standard permits at most one explicit instantiation definition per template
// specialization per program ([temp.spec]; no diagnostic required). Consider
// an `extern template` declaration here plus a single definition in one .cpp.
// Left unchanged because the original motivation (possibly a specific
// toolchain) is not visible from this header.
template class at::ArrayRef<torch::lazy::Value>;
} // namespace c10
299