1 | #pragma once |
2 | |
3 | #include <c10/core/ScalarType.h> |
4 | #include <c10/macros/Export.h> |
5 | #include <c10/util/Exception.h> |
6 | #include <c10/util/Optional.h> |
7 | |
8 | #include <type.h> |
9 | #include <utils.h> |
10 | |
11 | #include <cstdint> |
12 | #include <iostream> |
13 | #include <limits> |
14 | #include <memory> |
15 | #include <stdexcept> |
16 | #include <unordered_map> |
17 | #include <vector> |
18 | |
19 | // TODO: Add more types (int32, int64) |
20 | // TODO: sameAs should have better logic to check against any type and return |
21 | // gracefully |
22 | |
/*
 * This file defines the base IR structure. Any IR node in this system will
 * inherit from one of the following classes: Statement, Expr, Val,
 * IrInputOutput. IR is any information that the code generation stack may need
 * for analysis. By analysis we're referring to anything done in response to a
 * user-facing call of this stack. This could be careful tracking of user calls,
 * or any transformation, including optimizing transformations, user-declared
 * transformations, and lowering the IR.
 */
32 | |
33 | namespace torch { |
34 | namespace jit { |
35 | namespace fuser { |
36 | namespace cuda { |
37 | |
//! Integer identifier for values. NOTE(review): not referenced anywhere else
//! in this header; its users live in other files.
using ValueId = int32_t;

//! Type of the integer "name" attached to every Statement (see
//! Statement::name() / Statement::setName()).
using StmtNameType = unsigned int;

//! Sentinel meaning "no name assigned yet"; Statement::name_ is
//! default-initialized to it. Written in terms of StmtNameType (not a
//! hard-coded unsigned int) so the sentinel always remains the maximum of
//! whatever type the alias names.
constexpr StmtNameType kInvalidStmName =
    std::numeric_limits<StmtNameType>::max();
44 | |
// Forward declarations. In this header these types only appear behind
// pointers/references or as parameters of function declarations, where an
// incomplete type is sufficient. NOTE(review): FusionGuard, UnaryOp,
// BinaryOp, RNGOp and IterDomain are not referenced in this header at all --
// presumably kept for the convenience of includers; verify before removing.

// Fusion/container-level types.
class Fusion;
class FusionGuard;
class IrContainer;
class IrCloner;

// IR node types.
class Expr;
class Val;
class UnaryOp;
class BinaryOp;
class RNGOp;
class IterDomain;

// Passkey types that gate Statement/Val/Expr construction and setName().
class IrBuilderPasskey;
class IrContainerPasskey;

// Kernel IR types.
namespace kir {
class Kernel;
class Predicate;
} // namespace kir
62 | |
//! Passkey token that only Expr can mint: its sole constructor is private and
//! Expr is the only friend. Once minted it can be copied and handed on (Expr
//! exposes one through its protected exprPasskey()). Per the original intent,
//! it is used by the container to register names with statements.
//!
//! The constructor is user-provided (`{}`) rather than `= default` on
//! purpose: a defaulted constructor would leave the class an aggregate before
//! C++20, and `ExprPasskey{}` would then compile from anywhere, bypassing the
//! private access.
class ExprPasskey {
 private:
  friend class Expr;

  explicit ExprPasskey() {}
};
70 | |
71 | TORCH_CUDA_CU_API void swap(Fusion& a, Fusion& b) noexcept; |
72 | |
73 | //! Statement is the highest level node representation. Everything that is |
74 | //! considered "IR" will be derived from this class at some point. Both Values |
75 | //! and Expr's are a Statement. If there will ever be any more fundamental |
76 | //! types, they will also derive from Statement. |
77 | //! |
78 | //! We use Statements to pass around nodes of unknown compile type. Therefore it |
79 | //! is also important for the design to have a dispatch system for a Statment. |
80 | //! Basically beinng able to succienctly traverse down the inhereitance stack of |
81 | //! a Statment at runtime. This is currently implemented in dispatch.h |
82 | class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase { |
83 | friend void swap(Fusion&, Fusion&) noexcept; |
84 | friend void swap(IrContainer& a, IrContainer& b) noexcept; |
85 | |
86 | public: |
87 | Statement() = delete; |
88 | |
89 | // Cloning constructor |
90 | Statement(const Statement* src, IrCloner* ir_cloner); |
91 | |
92 | // Dispatch functions, definitions in dispatch.cpp |
93 | template <typename T> |
94 | static void dispatch(T handler, Statement*); |
95 | |
96 | template <typename T> |
97 | static void constDispatch(T handler, const Statement* const); |
98 | |
99 | template <typename T> |
100 | static void mutatorDispatch(T mutator, Statement*); |
101 | |
102 | // Accessor functions to types. Vals always have a DataType, Exprs never do |
103 | virtual c10::optional<ValType> getValType() const { |
104 | return c10::nullopt; |
105 | } |
106 | virtual c10::optional<DataType> getDataType() const { |
107 | return c10::nullopt; |
108 | } |
109 | virtual c10::optional<ExprType> getExprType() const { |
110 | return c10::nullopt; |
111 | } |
112 | |
113 | // Short cut to figure out if it is a value/expression |
114 | bool isVal() const { |
115 | return getValType() != c10::nullopt; |
116 | } |
117 | bool isExpr() const { |
118 | return getExprType() != c10::nullopt; |
119 | } |
120 | |
121 | // Make sure this is a Val and return it as a Val* |
122 | Val* asVal(); |
123 | |
124 | // Make sure this is an Expr and return it as an Expr* |
125 | Expr* asExpr(); |
126 | |
127 | // Return the fusion this statement belongs to |
128 | Fusion* fusion() const; |
129 | |
130 | // Return the kernel this statement belongs to |
131 | kir::Kernel* kernel() const; |
132 | |
133 | // Return the container this statement belongs to |
134 | IrContainer* container() const { |
135 | return ir_container_; |
136 | } |
137 | |
138 | // Return the int that represents its name |
139 | StmtNameType name() const { |
140 | return name_; |
141 | } |
142 | |
143 | // Set the statements' name. Typically the container will set the name, |
144 | // however if we're dealing with cloning, IrBuilder will set the name, this |
145 | // maybe should be from IrCloner, however I didn't want to add another |
146 | // passkey. |
147 | void setName(IrContainerPasskey, StmtNameType name); |
148 | void setName(IrBuilderPasskey, StmtNameType name); |
149 | |
150 | virtual bool sameType(const Statement* const other) { |
151 | if (isVal() && other->isVal()) |
152 | return getValType().value() == other->getValType().value(); |
153 | if (isExpr() && other->isExpr()) |
154 | return getExprType().value() == other->getExprType().value(); |
155 | return false; |
156 | } |
157 | |
158 | // Return if this statement is the same as another statement |
159 | // TODO: should this run through dispatch on this and other? |
160 | virtual bool sameAs(const Statement* other) const { |
161 | return this == other; |
162 | } |
163 | |
164 | std::string toString() const; |
165 | std::string toInlineString() const; |
166 | |
167 | protected: |
168 | Statement(IrBuilderPasskey); |
169 | |
170 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
171 | StmtNameType name_ = kInvalidStmName; |
172 | |
173 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
174 | IrContainer* ir_container_ = nullptr; |
175 | }; |
176 | |
177 | //! A Val represents a "value." These are objects, like tensors, scalars, and |
178 | //! memory locations, that are inputs and outputs of computations (represented |
179 | //! by Exprs, below) |
180 | //! |
181 | //! Vals are constant and unique and should always be passed |
182 | //! around as a pointer. Val can generally be thought of as representing any |
183 | //! type of data. Some examples: a constant size like convolution filter width a |
184 | //! runtime constant like batch normalizations momentum a "symbolic" tensor like |
185 | //! one passed down from the JIT a memory buffer used in device code |
186 | //! |
187 | //! Adding a Val: |
188 | //! Right now adding a Val is quite involved. Val's can be defined in ir.h or in |
189 | //! their own header file. The following is what is currently needed to add a |
190 | //! new Val: |
191 | //! |
192 | //! 1) Definition inheriting from Val |
193 | //! - Members must be private or protected |
194 | //! - Accessor functions for members |
195 | //! - Must call Val constructor, Val constructor registers with fusion |
196 | //! - Implementation of bool sameAs(...) |
197 | //! - Must implement a "cloning" constructor, ex. |
198 | //! Int::Int(const Int* src, IrCloner* ir_cloner) |
199 | //! 2) dispatch.h/.cpp must be updated to include dispatch of the new Val |
200 | //! 3) Default mutator function should be added to mutator.cpp |
201 | //! 4a) Printing functions should be added to ir_iostream.h/.cpp |
202 | //! 4b) Graphviz generation must be added to ir_graphviz.h/.cpp |
203 | //! 5) An enum value must be added to ValType in type.h |
204 | //! 6) A string entry must be added in val_type_string_map |
205 | //! |
206 | class TORCH_CUDA_CU_API Val : public Statement { |
207 | public: |
208 | explicit Val( |
209 | IrBuilderPasskey, |
210 | ValType _vtype, |
211 | DataType _dtype = DataType::Null); |
212 | |
213 | Val(const Val* src, IrCloner* ir_cloner); |
214 | |
215 | // Dispatch functions, definitions in dispatch.cpp |
216 | template <typename T> |
217 | static void dispatch(T handler, Val*); |
218 | |
219 | template <typename T> |
220 | static void constDispatch(T handler, const Val* const); |
221 | |
222 | template <typename T> |
223 | static void mutatorDispatch(T mutator, Val*); |
224 | |
225 | c10::optional<ValType> getValType() const override { |
226 | return vtype_; |
227 | } |
228 | |
229 | ValType vtype() const { |
230 | return vtype_; |
231 | } |
232 | |
233 | DataType dtype() const { |
234 | return dtype_; |
235 | } |
236 | |
237 | // Throws if no DataType is found. Vals must have a DataType |
238 | c10::optional<DataType> getDataType() const override; |
239 | |
240 | bool isScalar() const { |
241 | return vtype_ == ValType::Scalar || vtype_ == ValType::NamedScalar; |
242 | } |
243 | |
244 | // Returns if all dependencies are constant scalars |
245 | bool isConstScalar() const; |
246 | |
247 | // Returns if all dependencies are constant integers |
248 | bool isConstInt() const; |
249 | |
250 | bool isAnInt() const { |
251 | return isScalar() && dtype_ == DataType::Int; |
252 | } |
253 | |
254 | bool isADouble() const { |
255 | return isScalar() && dtype_ == DataType::Double; |
256 | } |
257 | |
258 | // If this Val is an integer with a direct constant value associated with it, |
259 | // will return the value of that constant integer. If this integer has |
260 | // defining expressions it will return a c10::nullopt. Those values should be |
261 | // infered using evaluateInt. |
262 | c10::optional<int64_t> getInt() const; |
263 | |
264 | // If this Val is a double with a direct constant value associated with it, |
265 | // will return the value of that constant double. If this double has |
266 | // defining expressions it will return a c10::nullopt. Those values should be |
267 | // infered using evaluateDouble. |
268 | c10::optional<double> getDouble() const; |
269 | |
270 | // If this Val is a constant integer, and its history is comprised only of |
271 | // constant values, will return the value of that constant integer. Cannot |
272 | // make constant as expression evaluator takes non-constant Vals. |
273 | int64_t evaluateInt(); |
274 | |
275 | // If this Val is a constant double, and its history is comprised only of |
276 | // constant values, will return the value of that constant double. Cannot |
277 | // make constant as expression evaluator takes non-constant Vals. |
278 | double evaluateDouble(); |
279 | |
280 | // Returns if no dependencies and is a constant scalar. |
281 | virtual bool isConst() const { |
282 | return false; |
283 | } |
284 | |
285 | bool isZeroInt() const; |
286 | bool isOneInt() const; |
287 | |
288 | // Returns the Expr that this value is an output of, returns nullptr if none |
289 | // was found |
290 | Expr* definition() const { |
291 | if (is_fusion_input_) { |
292 | return nullptr; |
293 | } |
294 | return definition_; |
295 | } |
296 | |
297 | // Determine if value definition matches given expression type |
298 | bool isDefinitionType(ExprType expression_type) const; |
299 | |
300 | const std::vector<Expr*>& uses() const; |
301 | |
302 | bool isFusionInput() const { |
303 | return is_fusion_input_; |
304 | } |
305 | |
306 | bool isFusionOutput() const { |
307 | return is_fusion_output_; |
308 | } |
309 | |
310 | //! Returns true when other is a producer of this |
311 | bool isProducerOf(const Val* other) const; |
312 | |
313 | //! Returns true when other is a consumer of this |
314 | bool isConsumerOf(const Val* other) const; |
315 | |
316 | bool sameType(const Statement* other) override { |
317 | return Statement::sameType(other) && |
318 | getDataType() == other->as<Val>()->getDataType(); |
319 | } |
320 | |
321 | // TODO: Make this more sophisticated. A value being the same as another value |
322 | // should be evaluated based on the DAG that created it, and that DAGs leaf |
323 | // nodes |
324 | bool sameAs(const Statement* other) const override { |
325 | return this == other; |
326 | } |
327 | |
328 | void setEvaluatorIndex(int to) { |
329 | TORCH_INTERNAL_ASSERT(evaluator_index_ == -1); |
330 | evaluator_index_ = to; |
331 | } |
332 | |
333 | int evaluatorIndex() const { |
334 | return evaluator_index_; |
335 | } |
336 | |
337 | // Following is managed by Fusion (or kirIrBuilder) and can change. |
338 | // TODO: Protect with a passkey. |
339 | void setDefinition(Expr* expr) { |
340 | definition_ = expr; |
341 | } |
342 | |
343 | void resolveIndexDtype(); |
344 | |
345 | protected: |
346 | friend Fusion; |
347 | |
348 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
349 | const ValType vtype_; |
350 | |
351 | // TODO: Add fusion passkey for this |
352 | void setIsFusionInput(bool is_fusion_input) { |
353 | is_fusion_input_ = is_fusion_input; |
354 | } |
355 | |
356 | // TODO: Add fusion passkey for this |
357 | void setIsFusionOutput(bool is_fusion_output) { |
358 | is_fusion_output_ = is_fusion_output; |
359 | } |
360 | |
361 | // TODO: Add fusion or container passkey for this |
362 | void setUses(const std::vector<Expr*>& uses) { |
363 | uses_ = uses; |
364 | } |
365 | |
366 | private: |
367 | // There's only one instance where dtype can change, and that's through |
368 | // resolving the index data type from nvfuser to either Int or Int32 for |
369 | // welford operations. |
370 | DataType dtype_; |
371 | |
372 | // Following is managed by Fusion and can change. |
373 | bool is_fusion_input_ = false; |
374 | bool is_fusion_output_ = false; |
375 | |
376 | Expr* definition_ = nullptr; |
377 | std::vector<Expr*> uses_; |
378 | |
379 | // Expr evaluator idx; |
380 | int evaluator_index_ = -1; |
381 | }; |
382 | |
383 | //! A Expr represents a "computation." These are functions that takes inputs |
384 | //! and produce outputs, inputs and outputs all being Vals. There are |
385 | //! specializations of BinaryOp which takes 2 inputs and produces 1 output, and |
386 | //! UnaryOp which takes 1 input and produces 1 output. Exprs are unique and |
387 | //! immutable. Conceptually, Exprs could always be manipulated using unique |
388 | //! pointers, and we could add this later. However, for now Exprs can be |
389 | //! replaced in a fusion, but they cannot be modified in place. |
390 | //! |
391 | //! The IR is static single assignment (SSA). Values can only be defined as an |
392 | //! output of an Expr once. If they are re-defined the original definition is |
393 | //! deleted from the program, as opposed to an ordered redefinition of the |
394 | //! value in the program. |
395 | //! |
396 | //! Note: Registering an Expr with a Fusion is actually 2 parts, one part is |
397 | //! done in the Expr constructor, so that should be called on anything that |
398 | //! inherits Expr. The issue with having registration in Expr's constructor, is |
399 | //! that the constructor of an Expr will set ouputs and inputs. This |
400 | //! information is important for registration with Fuser, so it can track the |
401 | //! dependency chain. |
402 | //! |
403 | //! Adding an Expr: |
404 | //! Right now adding an Expr is quite involved. Expr's can be defined in ir.h |
405 | //! or in their own header file. The following is what is currently needed for |
406 | //! Expr definitions: |
407 | //! |
408 | //! 1) Definition inheriting from Expr. |
409 | //! - Members must be private or protected |
410 | //! - Accessor functions for members |
411 | //! - Constructors need to register with the Fusion after inputs/outputs |
412 | //! are defined |
413 | //! - Implementation of bool sameAs(...) |
414 | //! 2) dispatch.h/.cpp must be updated to include dispatch of the new Val |
415 | //! 3) Default mutator function should be added to mutator.h/.cpp |
416 | //! 4) Printing functions should be added to ir_iostream.h/.cpp |
417 | //! 5) Lower case convenience functions should be added to arith.h/.cpp (If |
418 | //! user facing) |
419 | //! 6) An enum value must be added to ExprType in type.h |
420 | //! 7) A string entry must be added in expr_type_string_map |
421 | //! 8) Entry added to ir_graphviz .cpp/.h |
422 | //! |
423 | class TORCH_CUDA_CU_API Expr : public Statement { |
424 | public: |
425 | explicit Expr(IrBuilderPasskey, ExprType type); |
426 | |
427 | Expr(const Expr* src, IrCloner* ir_cloner); |
428 | |
429 | // Creates a new instance of the expression with all its field copied. |
430 | // Note that unlike IrCloner, this function only do a shallow copy |
431 | virtual Expr* shallowCopy() const = 0; |
432 | |
433 | c10::optional<ExprType> getExprType() const override { |
434 | return etype_; |
435 | } |
436 | |
437 | ExprType etype() const { |
438 | return etype_; |
439 | } |
440 | |
441 | bool sameAs(const Statement* other) const override; |
442 | |
443 | // Input/output accessors |
444 | const auto& inputs() const { |
445 | return inputs_; |
446 | } |
447 | |
448 | const auto& outputs() const { |
449 | return outputs_; |
450 | } |
451 | |
452 | auto input(size_t index) const { |
453 | return inputs_[index]; |
454 | } |
455 | |
456 | auto output(size_t index) const { |
457 | return outputs_[index]; |
458 | } |
459 | |
460 | // Dispatch functions, definitions in dispatch.cpp |
461 | template <typename T> |
462 | static void dispatch(T handler, Expr*); |
463 | |
464 | template <typename T> |
465 | static void constDispatch(T handler, const Expr* const); |
466 | |
467 | template <typename T> |
468 | static void mutatorDispatch(T mutator, Expr*); |
469 | |
470 | // TODO: Protect based on being in kernel container |
471 | kir::Predicate* predicate() const; |
472 | |
473 | // Creates a shallow copy the expression with the given predicate attached. |
474 | // TODO: Protect based on being in kernel container |
475 | Expr* withPredicate(kir::Predicate* predicate); |
476 | |
477 | // TODO: Protect based on being in kernel container |
478 | kir::Predicate* writePredicate() const; |
479 | |
480 | // Creates a shallow copy the expression with the given write-predicate |
481 | // attached. |
482 | // TODO: Protect based on being in kernel container |
483 | Expr* withWritePredicate(kir::Predicate* write_predicate); |
484 | |
485 | protected: |
486 | // TODO: Protect based on being in kernel container |
487 | void setPredicate(kir::Predicate* predicate); |
488 | |
489 | // TODO: Protect based on being in kernel container |
490 | void setWritePredicate(kir::Predicate* write_predicate); |
491 | |
492 | void copyPredicatesFrom(const Expr* expr); |
493 | |
494 | // TODO: Add Fusion passkey |
495 | void addInput(Val* input) { |
496 | TORCH_INTERNAL_ASSERT(input != nullptr); |
497 | inputs_.push_back(input); |
498 | } |
499 | |
500 | // TODO: Add Fusion passkey |
501 | void addOutput(Val* output) { |
502 | TORCH_INTERNAL_ASSERT(output != nullptr); |
503 | outputs_.push_back(output); |
504 | } |
505 | |
506 | ExprPasskey exprPasskey() { |
507 | return ExprPasskey(); |
508 | } |
509 | |
510 | private: |
511 | ExprType etype_ = ExprType::Invalid; |
512 | std::vector<Val*> inputs_; |
513 | std::vector<Val*> outputs_; |
514 | |
515 | kir::Predicate* predicate_ = nullptr; |
516 | |
517 | // Only used for reduction-related expressions |
518 | kir::Predicate* write_predicate_ = nullptr; |
519 | }; |
520 | |
521 | } // namespace cuda |
522 | } // namespace fuser |
523 | } // namespace jit |
524 | } // namespace torch |
525 | |