1#pragma once
2
3#include <c10/core/ScalarType.h>
4#include <c10/macros/Export.h>
5#include <c10/util/Exception.h>
6#include <c10/util/Optional.h>
7
8#include <type.h>
9#include <utils.h>
10
11#include <cstdint>
12#include <iostream>
13#include <limits>
14#include <memory>
15#include <stdexcept>
16#include <unordered_map>
17#include <vector>
18
19// TODO: Add more types (int32, int64)
20// TODO: sameAs should have better logic to check against any type and return
21// gracefully
22
/*
 * This file defines the base IR structure. Any IR node in this system will
 * inherit from one of the following classes: Statement, Expr, Val,
 * IrInputOutput. IR is any information that the code generation stack may
 * need for analysis. By analysis we're referring to anything done in response
 * to a user-facing call of this stack. This could be careful tracking of user
 * calls, and any transformation including optimizing transformations,
 * user-declared transformations, and lowering the IR.
 */
32
33namespace torch {
34namespace jit {
35namespace fuser {
36namespace cuda {
37
// Integer identifier type (not referenced elsewhere in this header).
using ValueId = int32_t;

// Type of the integer "name" every Statement carries (see Statement::name()).
using StmtNameType = unsigned int;

// Name held by a Statement before one has been assigned. Expressed in terms of
// StmtNameType (rather than a hard-coded `unsigned int`) so the sentinel stays
// the maximum of the actual name type if that alias is ever changed.
constexpr StmtNameType kInvalidStmName =
    std::numeric_limits<StmtNameType>::max();

class Fusion;
class FusionGuard;
class Expr;
class Val;
class UnaryOp;
class BinaryOp;
class RNGOp;
class IterDomain;
class IrCloner;
class IrContainer;
class IrBuilderPasskey;
class IrContainerPasskey;

namespace kir {
class Kernel;
class Predicate;
} // namespace kir
62
//! Access token that only Expr can construct directly (through the private
//! constructor plus the friend declaration below). Expr hands keys out via its
//! protected exprPasskey() accessor; an API taking an ExprPasskey parameter is
//! thereby restricted to callers that can obtain one.
class ExprPasskey {
 private:
  // Intentionally user-provided rather than `= default`: depending on the
  // language version, a defaulted constructor can leave the class an
  // aggregate, which would let arbitrary code mint a key with `ExprPasskey{}`.
  explicit ExprPasskey() {}

  friend class Expr;
};
70
71TORCH_CUDA_CU_API void swap(Fusion& a, Fusion& b) noexcept;
72
73//! Statement is the highest level node representation. Everything that is
74//! considered "IR" will be derived from this class at some point. Both Values
75//! and Expr's are a Statement. If there will ever be any more fundamental
76//! types, they will also derive from Statement.
77//!
78//! We use Statements to pass around nodes of unknown compile type. Therefore it
79//! is also important for the design to have a dispatch system for a Statment.
80//! Basically beinng able to succienctly traverse down the inhereitance stack of
81//! a Statment at runtime. This is currently implemented in dispatch.h
82class TORCH_CUDA_CU_API Statement : public NonCopyable, public PolymorphicBase {
83 friend void swap(Fusion&, Fusion&) noexcept;
84 friend void swap(IrContainer& a, IrContainer& b) noexcept;
85
86 public:
87 Statement() = delete;
88
89 // Cloning constructor
90 Statement(const Statement* src, IrCloner* ir_cloner);
91
92 // Dispatch functions, definitions in dispatch.cpp
93 template <typename T>
94 static void dispatch(T handler, Statement*);
95
96 template <typename T>
97 static void constDispatch(T handler, const Statement* const);
98
99 template <typename T>
100 static void mutatorDispatch(T mutator, Statement*);
101
102 // Accessor functions to types. Vals always have a DataType, Exprs never do
103 virtual c10::optional<ValType> getValType() const {
104 return c10::nullopt;
105 }
106 virtual c10::optional<DataType> getDataType() const {
107 return c10::nullopt;
108 }
109 virtual c10::optional<ExprType> getExprType() const {
110 return c10::nullopt;
111 }
112
113 // Short cut to figure out if it is a value/expression
114 bool isVal() const {
115 return getValType() != c10::nullopt;
116 }
117 bool isExpr() const {
118 return getExprType() != c10::nullopt;
119 }
120
121 // Make sure this is a Val and return it as a Val*
122 Val* asVal();
123
124 // Make sure this is an Expr and return it as an Expr*
125 Expr* asExpr();
126
127 // Return the fusion this statement belongs to
128 Fusion* fusion() const;
129
130 // Return the kernel this statement belongs to
131 kir::Kernel* kernel() const;
132
133 // Return the container this statement belongs to
134 IrContainer* container() const {
135 return ir_container_;
136 }
137
138 // Return the int that represents its name
139 StmtNameType name() const {
140 return name_;
141 }
142
143 // Set the statements' name. Typically the container will set the name,
144 // however if we're dealing with cloning, IrBuilder will set the name, this
145 // maybe should be from IrCloner, however I didn't want to add another
146 // passkey.
147 void setName(IrContainerPasskey, StmtNameType name);
148 void setName(IrBuilderPasskey, StmtNameType name);
149
150 virtual bool sameType(const Statement* const other) {
151 if (isVal() && other->isVal())
152 return getValType().value() == other->getValType().value();
153 if (isExpr() && other->isExpr())
154 return getExprType().value() == other->getExprType().value();
155 return false;
156 }
157
158 // Return if this statement is the same as another statement
159 // TODO: should this run through dispatch on this and other?
160 virtual bool sameAs(const Statement* other) const {
161 return this == other;
162 }
163
164 std::string toString() const;
165 std::string toInlineString() const;
166
167 protected:
168 Statement(IrBuilderPasskey);
169
170 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
171 StmtNameType name_ = kInvalidStmName;
172
173 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
174 IrContainer* ir_container_ = nullptr;
175};
176
177//! A Val represents a "value." These are objects, like tensors, scalars, and
178//! memory locations, that are inputs and outputs of computations (represented
179//! by Exprs, below)
180//!
181//! Vals are constant and unique and should always be passed
182//! around as a pointer. Val can generally be thought of as representing any
183//! type of data. Some examples: a constant size like convolution filter width a
184//! runtime constant like batch normalizations momentum a "symbolic" tensor like
185//! one passed down from the JIT a memory buffer used in device code
186//!
187//! Adding a Val:
188//! Right now adding a Val is quite involved. Val's can be defined in ir.h or in
189//! their own header file. The following is what is currently needed to add a
190//! new Val:
191//!
192//! 1) Definition inheriting from Val
193//! - Members must be private or protected
194//! - Accessor functions for members
195//! - Must call Val constructor, Val constructor registers with fusion
196//! - Implementation of bool sameAs(...)
197//! - Must implement a "cloning" constructor, ex.
198//! Int::Int(const Int* src, IrCloner* ir_cloner)
199//! 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
200//! 3) Default mutator function should be added to mutator.cpp
201//! 4a) Printing functions should be added to ir_iostream.h/.cpp
202//! 4b) Graphviz generation must be added to ir_graphviz.h/.cpp
203//! 5) An enum value must be added to ValType in type.h
204//! 6) A string entry must be added in val_type_string_map
205//!
206class TORCH_CUDA_CU_API Val : public Statement {
207 public:
208 explicit Val(
209 IrBuilderPasskey,
210 ValType _vtype,
211 DataType _dtype = DataType::Null);
212
213 Val(const Val* src, IrCloner* ir_cloner);
214
215 // Dispatch functions, definitions in dispatch.cpp
216 template <typename T>
217 static void dispatch(T handler, Val*);
218
219 template <typename T>
220 static void constDispatch(T handler, const Val* const);
221
222 template <typename T>
223 static void mutatorDispatch(T mutator, Val*);
224
225 c10::optional<ValType> getValType() const override {
226 return vtype_;
227 }
228
229 ValType vtype() const {
230 return vtype_;
231 }
232
233 DataType dtype() const {
234 return dtype_;
235 }
236
237 // Throws if no DataType is found. Vals must have a DataType
238 c10::optional<DataType> getDataType() const override;
239
240 bool isScalar() const {
241 return vtype_ == ValType::Scalar || vtype_ == ValType::NamedScalar;
242 }
243
244 // Returns if all dependencies are constant scalars
245 bool isConstScalar() const;
246
247 // Returns if all dependencies are constant integers
248 bool isConstInt() const;
249
250 bool isAnInt() const {
251 return isScalar() && dtype_ == DataType::Int;
252 }
253
254 bool isADouble() const {
255 return isScalar() && dtype_ == DataType::Double;
256 }
257
258 // If this Val is an integer with a direct constant value associated with it,
259 // will return the value of that constant integer. If this integer has
260 // defining expressions it will return a c10::nullopt. Those values should be
261 // infered using evaluateInt.
262 c10::optional<int64_t> getInt() const;
263
264 // If this Val is a double with a direct constant value associated with it,
265 // will return the value of that constant double. If this double has
266 // defining expressions it will return a c10::nullopt. Those values should be
267 // infered using evaluateDouble.
268 c10::optional<double> getDouble() const;
269
270 // If this Val is a constant integer, and its history is comprised only of
271 // constant values, will return the value of that constant integer. Cannot
272 // make constant as expression evaluator takes non-constant Vals.
273 int64_t evaluateInt();
274
275 // If this Val is a constant double, and its history is comprised only of
276 // constant values, will return the value of that constant double. Cannot
277 // make constant as expression evaluator takes non-constant Vals.
278 double evaluateDouble();
279
280 // Returns if no dependencies and is a constant scalar.
281 virtual bool isConst() const {
282 return false;
283 }
284
285 bool isZeroInt() const;
286 bool isOneInt() const;
287
288 // Returns the Expr that this value is an output of, returns nullptr if none
289 // was found
290 Expr* definition() const {
291 if (is_fusion_input_) {
292 return nullptr;
293 }
294 return definition_;
295 }
296
297 // Determine if value definition matches given expression type
298 bool isDefinitionType(ExprType expression_type) const;
299
300 const std::vector<Expr*>& uses() const;
301
302 bool isFusionInput() const {
303 return is_fusion_input_;
304 }
305
306 bool isFusionOutput() const {
307 return is_fusion_output_;
308 }
309
310 //! Returns true when other is a producer of this
311 bool isProducerOf(const Val* other) const;
312
313 //! Returns true when other is a consumer of this
314 bool isConsumerOf(const Val* other) const;
315
316 bool sameType(const Statement* other) override {
317 return Statement::sameType(other) &&
318 getDataType() == other->as<Val>()->getDataType();
319 }
320
321 // TODO: Make this more sophisticated. A value being the same as another value
322 // should be evaluated based on the DAG that created it, and that DAGs leaf
323 // nodes
324 bool sameAs(const Statement* other) const override {
325 return this == other;
326 }
327
328 void setEvaluatorIndex(int to) {
329 TORCH_INTERNAL_ASSERT(evaluator_index_ == -1);
330 evaluator_index_ = to;
331 }
332
333 int evaluatorIndex() const {
334 return evaluator_index_;
335 }
336
337 // Following is managed by Fusion (or kirIrBuilder) and can change.
338 // TODO: Protect with a passkey.
339 void setDefinition(Expr* expr) {
340 definition_ = expr;
341 }
342
343 void resolveIndexDtype();
344
345 protected:
346 friend Fusion;
347
348 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
349 const ValType vtype_;
350
351 // TODO: Add fusion passkey for this
352 void setIsFusionInput(bool is_fusion_input) {
353 is_fusion_input_ = is_fusion_input;
354 }
355
356 // TODO: Add fusion passkey for this
357 void setIsFusionOutput(bool is_fusion_output) {
358 is_fusion_output_ = is_fusion_output;
359 }
360
361 // TODO: Add fusion or container passkey for this
362 void setUses(const std::vector<Expr*>& uses) {
363 uses_ = uses;
364 }
365
366 private:
367 // There's only one instance where dtype can change, and that's through
368 // resolving the index data type from nvfuser to either Int or Int32 for
369 // welford operations.
370 DataType dtype_;
371
372 // Following is managed by Fusion and can change.
373 bool is_fusion_input_ = false;
374 bool is_fusion_output_ = false;
375
376 Expr* definition_ = nullptr;
377 std::vector<Expr*> uses_;
378
379 // Expr evaluator idx;
380 int evaluator_index_ = -1;
381};
382
383//! A Expr represents a "computation." These are functions that takes inputs
384//! and produce outputs, inputs and outputs all being Vals. There are
385//! specializations of BinaryOp which takes 2 inputs and produces 1 output, and
386//! UnaryOp which takes 1 input and produces 1 output. Exprs are unique and
387//! immutable. Conceptually, Exprs could always be manipulated using unique
388//! pointers, and we could add this later. However, for now Exprs can be
389//! replaced in a fusion, but they cannot be modified in place.
390//!
391//! The IR is static single assignment (SSA). Values can only be defined as an
392//! output of an Expr once. If they are re-defined the original definition is
393//! deleted from the program, as opposed to an ordered redefinition of the
394//! value in the program.
395//!
396//! Note: Registering an Expr with a Fusion is actually 2 parts, one part is
397//! done in the Expr constructor, so that should be called on anything that
398//! inherits Expr. The issue with having registration in Expr's constructor, is
399//! that the constructor of an Expr will set ouputs and inputs. This
400//! information is important for registration with Fuser, so it can track the
401//! dependency chain.
402//!
403//! Adding an Expr:
404//! Right now adding an Expr is quite involved. Expr's can be defined in ir.h
405//! or in their own header file. The following is what is currently needed for
406//! Expr definitions:
407//!
408//! 1) Definition inheriting from Expr.
409//! - Members must be private or protected
410//! - Accessor functions for members
411//! - Constructors need to register with the Fusion after inputs/outputs
412//! are defined
413//! - Implementation of bool sameAs(...)
414//! 2) dispatch.h/.cpp must be updated to include dispatch of the new Val
415//! 3) Default mutator function should be added to mutator.h/.cpp
416//! 4) Printing functions should be added to ir_iostream.h/.cpp
417//! 5) Lower case convenience functions should be added to arith.h/.cpp (If
418//! user facing)
419//! 6) An enum value must be added to ExprType in type.h
420//! 7) A string entry must be added in expr_type_string_map
421//! 8) Entry added to ir_graphviz .cpp/.h
422//!
423class TORCH_CUDA_CU_API Expr : public Statement {
424 public:
425 explicit Expr(IrBuilderPasskey, ExprType type);
426
427 Expr(const Expr* src, IrCloner* ir_cloner);
428
429 // Creates a new instance of the expression with all its field copied.
430 // Note that unlike IrCloner, this function only do a shallow copy
431 virtual Expr* shallowCopy() const = 0;
432
433 c10::optional<ExprType> getExprType() const override {
434 return etype_;
435 }
436
437 ExprType etype() const {
438 return etype_;
439 }
440
441 bool sameAs(const Statement* other) const override;
442
443 // Input/output accessors
444 const auto& inputs() const {
445 return inputs_;
446 }
447
448 const auto& outputs() const {
449 return outputs_;
450 }
451
452 auto input(size_t index) const {
453 return inputs_[index];
454 }
455
456 auto output(size_t index) const {
457 return outputs_[index];
458 }
459
460 // Dispatch functions, definitions in dispatch.cpp
461 template <typename T>
462 static void dispatch(T handler, Expr*);
463
464 template <typename T>
465 static void constDispatch(T handler, const Expr* const);
466
467 template <typename T>
468 static void mutatorDispatch(T mutator, Expr*);
469
470 // TODO: Protect based on being in kernel container
471 kir::Predicate* predicate() const;
472
473 // Creates a shallow copy the expression with the given predicate attached.
474 // TODO: Protect based on being in kernel container
475 Expr* withPredicate(kir::Predicate* predicate);
476
477 // TODO: Protect based on being in kernel container
478 kir::Predicate* writePredicate() const;
479
480 // Creates a shallow copy the expression with the given write-predicate
481 // attached.
482 // TODO: Protect based on being in kernel container
483 Expr* withWritePredicate(kir::Predicate* write_predicate);
484
485 protected:
486 // TODO: Protect based on being in kernel container
487 void setPredicate(kir::Predicate* predicate);
488
489 // TODO: Protect based on being in kernel container
490 void setWritePredicate(kir::Predicate* write_predicate);
491
492 void copyPredicatesFrom(const Expr* expr);
493
494 // TODO: Add Fusion passkey
495 void addInput(Val* input) {
496 TORCH_INTERNAL_ASSERT(input != nullptr);
497 inputs_.push_back(input);
498 }
499
500 // TODO: Add Fusion passkey
501 void addOutput(Val* output) {
502 TORCH_INTERNAL_ASSERT(output != nullptr);
503 outputs_.push_back(output);
504 }
505
506 ExprPasskey exprPasskey() {
507 return ExprPasskey();
508 }
509
510 private:
511 ExprType etype_ = ExprType::Invalid;
512 std::vector<Val*> inputs_;
513 std::vector<Val*> outputs_;
514
515 kir::Predicate* predicate_ = nullptr;
516
517 // Only used for reduction-related expressions
518 kir::Predicate* write_predicate_ = nullptr;
519};
520
521} // namespace cuda
522} // namespace fuser
523} // namespace jit
524} // namespace torch
525