#pragma once

#include <c10/macros/Export.h>

#include <fusion.h>
#include <ir_base_nodes.h>
#include <ir_internal_nodes.h>
#include <mma_type.h>

#include <torch/csrc/jit/ir/ir.h>

//! Nodes in here are intended to be "user facing": users in this sense being
//! those that want to be able to generate CUDA code.

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

class WelfordResult;
class ViewTransform;

class IrCloner;
class IrBuilderPasskey;

//! A Bool value
//!
//! This value can be a symbolic value (defined after the kernel
//! is compiled) or a constant value (inlined into the kernel definition).
//!
class TORCH_CUDA_CU_API Bool : public Val {
 public:
  Bool(IrBuilderPasskey passkey);

  explicit Bool(IrBuilderPasskey passkey, bool value);

  explicit Bool(IrBuilderPasskey passkey, c10::optional<bool> value);

  Bool(const Bool* src, IrCloner* ir_cloner);

  bool isSymbolic() const {
    return !(maybe_value_.has_value());
  }
  bool isConst() const final {
    return maybe_value_.has_value();
  }
  c10::optional<bool> value() const {
    return maybe_value_;
  }

  bool sameAs(const Statement* other) const override;

 private:
  const c10::optional<bool> maybe_value_;
};
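
//! Example (a hedged sketch, not part of the original header): assuming
//! IrBuilder::create<> is the usual construction path that supplies the
//! IrBuilderPasskey, a symbolic and a constant Bool would be created as:
//!
//!   Bool* symbolic = IrBuilder::create<Bool>();     // value bound at runtime
//!   Bool* constant = IrBuilder::create<Bool>(true); // inlined into the kernel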

//! A Float64 value. This value can be a symbolic value (defined after the
//! kernel is compiled) or a constant value (inlined into the kernel
//! definition).
class TORCH_CUDA_CU_API Double : public Val {
 public:
  using ScalarType = double;

  Double(IrBuilderPasskey passkey);

  explicit Double(IrBuilderPasskey passkey, ScalarType value);

  explicit Double(IrBuilderPasskey passkey, c10::optional<ScalarType> value);

  Double(const Double* src, IrCloner* ir_cloner);

  bool isSymbolic() const {
    return !(maybe_value_.has_value());
  }
  bool isConst() const final {
    return maybe_value_.has_value();
  }
  c10::optional<ScalarType> value() const {
    return maybe_value_;
  }

  bool sameAs(const Statement* other) const override;

 private:
  const c10::optional<ScalarType> maybe_value_;
};

//! An Int64 value. If used for indexing it's set as size_t. Otherwise it's an
//! inlined literal in the kernel.
class TORCH_CUDA_CU_API Int : public Val {
 public:
  using ScalarType = int64_t;

  Int(IrBuilderPasskey passkey);

  explicit Int(IrBuilderPasskey passkey, ScalarType value);

  explicit Int(IrBuilderPasskey passkey, c10::optional<ScalarType> value);

  Int(const Int* src, IrCloner* ir_cloner);

  bool isSymbolic() const {
    return !(maybe_value_.has_value());
  }
  bool isConst() const final {
    return maybe_value_.has_value();
  }
  c10::optional<ScalarType> value() const {
    return maybe_value_;
  }

  bool sameAs(const Statement* other) const override;

 private:
  const c10::optional<ScalarType> maybe_value_;
};

//! A c10::complex<double> value. This value can be a symbolic value (defined
//! after the kernel is compiled) or a constant value (inlined into the kernel
//! definition).
class TORCH_CUDA_CU_API ComplexDouble : public Val {
 public:
  using ScalarType = c10::complex<double>;

  ComplexDouble(IrBuilderPasskey passkey);

  explicit ComplexDouble(IrBuilderPasskey passkey, ScalarType value);

  explicit ComplexDouble(
      IrBuilderPasskey passkey,
      c10::optional<ScalarType> value);

  ComplexDouble(const ComplexDouble* src, IrCloner* ir_cloner);

  bool isSymbolic() const {
    return !(maybe_value_.has_value());
  }
  bool isConst() const final {
    return maybe_value_.has_value();
  }
  c10::optional<ScalarType> value() const {
    return maybe_value_;
  }

  bool sameAs(const Statement* other) const override;

 private:
  const c10::optional<ScalarType> maybe_value_;
};

//! Mode during propagation of computeAt. Standard will throw an error if the
//! provided computeAt position can't be satisfied, best effort will lower the
//! computeAt position as needed during traversal, and most inlined will
//! increase the computeAt position to the maximum possible through traversal.
enum class ComputeAtMode { Standard, BestEffort, MostInlined };

class TransformPropagator;
struct MostInlinedTransformPropagator;
class TransformIter;
class TransformReplay;
class OptOutMutator;
class TensorDomain;

class MaxPosCalculator;

namespace ir_utils {
class TVDomainGuard;
}

//! TensorView is our primitive Tensor Type used in code generation. It can be
//! thought of as representing physical memory, however, its dimensionality is
//! modified as split/merge/computeAt functions are called. The history of
//! these transformations is kept and used for generating actual code
//! referencing physical memory. Generally when users are thinking of code
//! generation in reference to a Tensor, this is the class they should be
//! interacting with.
//!
//! The reason we need both TensorView and TensorDomain is that we need to have
//! a record of both what is being computed and how it is being computed. For
//! example we may have the operation:
//!
//!   TV3[I, J, K] = TV2[I, J, K] + TV1[I, J, K]
//!
//! The mathematical operations here are on the tensor views TV1, TV2, and
//! TV3. This operation is a pointwise operation. To compute this pointwise
//! operation we iterate over the 3D TensorDomain [I, J, K], where K is the
//! fastest changing dimension.
//!
//! \todo Need to work on the const model for TensorView, making all functions
//! that should be const, const. Gave this a try but expanded really quickly.
//! getComputeAtAxis not being const because it can return a TV that some
//! expect to be non-const is the biggest headache.
//!
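//! A hedged illustration of the pointwise example above (assuming the
//! arithmetic helpers declared elsewhere in this fuser, e.g. add()):
//!
//!   TensorView* tv3 = add(tv2, tv1); // TV3[I, J, K] = TV2[I, J, K] + TV1[I, J, K]
//!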
class TORCH_CUDA_CU_API TensorView : public Val {
 public:
  TensorView(
      IrBuilderPasskey passkey,
      TensorDomain* domain,
      DataType dtype,
      MemoryType mtype = MemoryType::Local);

  explicit TensorView(
      IrBuilderPasskey passkey,
      const std::shared_ptr<c10::TensorType>& tensor_type);

  explicit TensorView(
      IrBuilderPasskey passkey,
      const std::shared_ptr<Value>& jit_value);

  TensorView(const TensorView* src, IrCloner* ir_cloner);

  TensorDomain* domain() const {
    return domain_;
  }

  //! This is for a TensorView with an rFactor domain that is an input to a
  //! fusion segment. We convert the rfactor domain into a new root domain.
  //! Any dynamic-sized rfactor iterDomains are given a new symbolic extent.
  //! Concrete integer extents are kept. Output TensorViews of any subsequent
  //! expressions that use this TensorView are also updated.
  void convertRfactorToRootDomain();

  void setContiguity(const std::vector<bool>& contig) {
    domain()->setContiguity(contig);
  }

  void setContiguity(bool contig) {
    setContiguity(std::vector<bool>(domain()->contiguity().size(), contig));
  }

  bool hasReduction() const;
  bool hasBlockReduction() const;
  bool hasGridReduction() const;
  bool hasBroadcast() const;
  bool hasRFactor() const;

  //! This is the previous hasReduction logic, kept here exclusively for the
  //! lower loop pass; it will be deprecated once the Fusion IR pass can
  //! convert trivial reductions.
  bool hasAnyReduction() const;

  //! Returns true if this tensor is zero dimensional,
  //! i.e. a wrapped scalar or an empty placeholder.
  bool isZeroDim() const {
    return nDims() == 0;
  }

  //! Returns true if this tensor does not contain
  //! any value.
  bool isEmptyTensor() const;

  c10::optional<unsigned int> getReductionAxis() const;

  const std::vector<IterDomain*>& getRootDomain() const;

  const std::vector<IterDomain*>& getRFactorDomain() const;

  // If rfactor domain exists in domain() return it, otherwise return root
  // domain.
  const std::vector<IterDomain*>& getMaybeRFactorDomain() const;

  IterDomain* axis(int pos) const;

  // Does it share outer axes with other tensors?
  bool hasComputeAt() const {
    return compute_at_pos_ > 0;
  }

  bool hasMaxProducerPosition() const {
    return max_producer_pos_ > 0;
  }

  size_t nDims() const;

  // Sets the cpu_scalar_ value, which provides special handling for CPU based
  // zero-dim tensors (i.e. CPU Tensors that only have one value). This is only
  // used on input values, otherwise it is ignored. The special handling is
  // important because these "scalars" should be type promoted as a tensor, but
  // we want to avoid explicitly copying the data, so we want to pass the data
  // value as a standard kernel argument value.
  void setCpuScalar(bool is_cpu_scalar);

  // Returns the cpu_scalar_ value, which provides special handling for CPU
  // based zero-dim tensors (i.e. CPU Tensors that only have one value). This
  // is only used on input values, otherwise it is ignored. The special
  // handling is important because these "scalars" should be type promoted as
  // a tensor, but we want to avoid explicitly copying the data, so we want to
  // pass the data value as a standard kernel argument value.
  bool isCpuScalar() const {
    return cpu_scalar_;
  }

  // Returns the position that this tensor is produced at relative to its axes.
  unsigned int getComputeAtPosition() const {
    return compute_at_pos_;
  }

  // Returns the maximum position at which producers are computed relative to
  // this tensor. This position dictates the clear expectations of producers.
  unsigned int getMaxProducerPosition() const {
    return max_producer_pos_;
  }

  //! This is used when we disconnect a tensorview from a reduction
  //! operation and connect it to a non-reduction operator. We need
  //! to remove the reduction ids on the tv in this case.
  //! Currently only used in translate welford, and this function may
  //! be refactored or extended if any more use cases appear.
  void clearReductionIterDomains();

  //! Compute this TensorView relative to a consumer position, -1 will
  //! compute tensors inline with each other, 0 doesn't share
  //! any loop nests between the tensors. It's an error when the given
  //! position is not legally viable. Alternatively, when the mode
  //! parameter is ComputeAtMode::BestEffort, the position is lowered
  //! one by one until a valid position is found. When
  //! ComputeAtMode::MostInlined is given, the position parameter is
  //! ignored, and the deepest possible position is searched.
  TensorView* computeAt(
      TensorView* consumer,
      int position,
      ComputeAtMode mode = ComputeAtMode::Standard);
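
  // A hedged usage sketch of the three modes (assuming producer/consumer are
  // already connected TensorViews in the fusion):
  //   producer->computeAt(consumer, 2); // error if position 2 is not viable
  //   producer->computeAt(consumer, 2, ComputeAtMode::BestEffort);   // lowers 2 as needed
  //   producer->computeAt(consumer, -1, ComputeAtMode::MostInlined); // position ignored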

  //! Compute this tensor to consumer, at local position, -1 will compute
  //! tensors inline with each other, 0 doesn't share any loop nests between
  //! the tensors. The mode parameter can be used in the same manner as
  //! computeAt.
  TensorView* computeWith(
      TensorView* consumer,
      int position,
      ComputeAtMode mode = ComputeAtMode::Standard);

  // Split "axis" into 2 axes
  //! inner_split dictates if the factor section of the split should be inside
  //! the remainder or outside.
  //! e.g. split(0, 4, inner_split = true) will result in:
  //! tv[id{extent}] -> tv[id{ceilDiv(extent, factor)}, id{factor}]
  //! e.g. split(0, 4, inner_split = false) will result in:
  //! tv[id{extent}] -> tv[id{factor}, id{ceilDiv(extent, factor)}]
  //!
  //! When trim_out_of_bounds is true, only the inner domain defined by the
  //! start and stop positions is split.
  TensorView* split(
      int axis,
      unsigned int factor,
      bool inner_split = true,
      bool trim_out_of_bounds = false);
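
  // A minimal sketch: starting from a 1-D tv with extent 128, either of
  //   tv->split(0, 4);        // -> tv[ceilDiv(128, 4), 4]
  //   tv->split(0, 4, false); // -> tv[4, ceilDiv(128, 4)]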

  // Split "axis" into 2 axes where the inner axis is of size "factor"
  // and the outer axis is of size axis.size() / factor. Factor can be a
  // symbolic value instead of a constant. This requires setting the symbolic
  // value as an input, or using a parallel dim from
  // NamedScalar::getParallelDim
  TensorView* split(
      int axis,
      Val* factor,
      bool inner_split = true,
      bool trim_out_of_bounds = false);

  // Merge axis_o and axis_i into 1 IterDomain
  TensorView* merge(int axis_o, int axis_i);

  // Merge axis and axis+1 into 1 IterDomain
  TensorView* merge(int axis) {
    return merge(axis, axis + 1);
  }
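
  // e.g. (a hedged sketch) starting from tv[I0, I1, I2]:
  //   tv->merge(0); // -> tv[I0*I1, I2]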

  // Reorder axes according to old2new[old_pos] = new_pos
  TensorView* reorder(const std::unordered_map<int, int>& old2new);
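
  // e.g. (a hedged sketch) tv[I0, I1, I2] with old2new = {{0, 2}}, which moves
  // axis 0 to position 2 and keeps the relative order of the other axes:
  //   tv->reorder({{0, 2}}); // -> tv[I1, I2, I0]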

  //! Swizzle indices to improve memory access efficiency.
  //!
  //! Swizzle::Transpose is a pattern commonly used to avoid bank
  //! conflicts in shared memory. It takes two axes and shifts the
  //! second axis by the first axis as ((axis1 + axis2) % extent). The
  //! memory type must be Shared.
  //!
  //! \param type Swizzle pattern such as transpose.
  //! \param axes Axes to swizzle
  TensorView* swizzle(SwizzleType type, const std::vector<int>& axes);
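
  //! e.g. (a hedged sketch, assuming tv is in shared memory and Transpose is
  //! the SwizzleType entry for the pattern described above):
  //!   tv->swizzle(SwizzleType::Transpose, {1, 2});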

  //! Swizzle the rectangular tile defined by the iterdomains corresponding
  //! to the 2 given indices.
  TensorView* swizzle(
      Swizzle2DType swizzle_type,
      int x,
      int y,
      SwizzleMode swizzle_mode = SwizzleMode::Data);

  // WARNING: rFactor does not return this TensorView, it returns a new
  //  TensorView consumed by this!
  //
  // Take reduction axes out of this domain, and create a new
  // domain. New domain will be used to create this domain.
  //
  // For example:
  //  TV1[I0, R1, R2, I3] = TV0[I0, I1, I2, I3]
  //
  // After:
  //  TV1->rFactor({1}), TV1 is transformed to -> TV1[I0, R2, I3]
  //
  // The TensorView returned is: TV2[I0, R1, I2, I3]
  //
  // The reduction will now be set as:
  //  TV2[I0, R1, I2, I3] = TV0[I0, I1, I2, I3]
  //  TV1[I0, R2, I3] = TV2[I0, R1, I2, I3]
  //
  TensorView* rFactor(const std::vector<int>& axes);

  //! Multi-output version of rFactor, semantically similar to
  //! the reduction version except that the rfactor is done
  //! for all outputs in a consistent way
  std::vector<TensorView*> rFactor(
      const std::vector<int>& axes,
      const std::vector<TensorView*>& tvs);

  //! Create a TensorView before the original tensor. A common use case is to
  //! write results into shared memory or registers before moving to global
  //! memory. Analogous to TVM Cache_Write
  //!
  //! @param cache_op: memory operator to use for the inserted op between
  //!   the data tensor and the cache tensor
  TensorView* cacheBefore(
      c10::optional<LoadStoreOpType> cache_op = c10::nullopt);

  //! Create a TensorView after the original tensor. A common use case is to
  //! read tensor into shared memory or registers. Analogous to TVM Cache_Read
  //!
  //! @param cache_op: memory operator to use for the inserted op between
  //!   the data tensor and the cache tensor
  TensorView* cacheAfter(
      c10::optional<LoadStoreOpType> cache_op = c10::nullopt);
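
  // A hedged usage sketch: staging a global-memory input in shared memory
  // before it is consumed (assuming tv is a fusion input with consumers):
  //   TensorView* tv_smem = tv->cacheAfter();
  //   tv_smem->setMemoryType(MemoryType::Shared);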

  // For a fusion output with other uses, we want to avoid writing to global
  // memory and then reading the output again. We write to global memory
  // separately after an operation. We replace this fusion output with the
  // direct write TensorView.
  TensorView* cacheFork();

  MemoryType getMemoryType() const {
    return memory_type_;
  }

  void setMemoryType(MemoryType mt);

  SwizzleType swizzleType() const {
    return swizzle_type_;
  }

  const std::vector<IterDomain*>& axesToSwizzle() const {
    return axes_to_swizzle_;
  }

  // Apply double buffering transformation
  void doubleBuffer();

  // Apply circular buffering transformation
  void circularBuffer(unsigned int number_of_stage);
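
  // A hedged usage sketch (assuming tv is a shared-memory tensor produced by
  // a copy that can be pipelined), use one of:
  //   tv->doubleBuffer();    // two buffers
  //   tv->circularBuffer(4); // four-stage pipeline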

  // Returns true if this tensor is double buffered.
  bool isDoubleBuffered() const {
    return is_double_buffered_;
  }

  // Returns true if this tensor is circular buffered.
  bool isCircularBuffered() const {
    return is_circular_buffered_;
  }

  // Returns the depth of circular buffering if applicable.
  unsigned int circularBufferDepth() const {
    TORCH_INTERNAL_ASSERT(
        is_circular_buffered_, toString(), "not circular buffered");
    return circular_buffer_stage_;
  }

  //! Transforms the innermost iterdomains according to the given mma swizzle,
  //! this should be used on the tvs that are either inputs/outputs of an
  //! MmaOp, or any tvs that are involved in prolog/epilog fusions and need to
  //! have a matching thread swizzle with the mma operand/result.
  //! For more detail on usage, see [WarpMmaSwizzler] in scheduler/mma_utils.h.
  void applyMmaSwizzle(MmaOptions options);

  //! Returns whether this tensor view has a swizzle operator on its tensor
  //! domain. This is a temporary flag for indicating that the new swizzle
  //! implementation is used and will be removed in follow ups.
  bool hasSwizzleOp() const {
    return has_swizzle_op_;
  }

  friend TORCH_CUDA_CU_API TransformPropagator;
  friend TORCH_CUDA_CU_API MostInlinedTransformPropagator;
  friend TORCH_CUDA_CU_API TransformReplay;
  friend TORCH_CUDA_CU_API OptOutMutator;
  friend class InlineBatchingGuard;
  friend class ir_utils::TVDomainGuard;

  // Inline the computation of this tensor into its consumer at the given
  // position. If this tensor is already inlined at a higher position, then
  // this call is a no-op. If the rightmost dimensions before `pos` are
  // broadcast dimensions, inlining will not go into them. If best_effort is
  // set, the tensor is inlined at the highest allowed position that is <=
  // `pos`.
  void inlineAt(
      int64_t pos,
      bool best_effort = false,
      MaxPosCalculator* calc = nullptr);
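
  // A hedged usage sketch: inline tv into its consumer at position 2, or at
  // the deepest legal position at or below 2 if 2 itself is not allowed:
  //   tv->inlineAt(2, /*best_effort=*/true);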

  // Update the max producer position of the current tensor. This is required
  // when we modify producer-consumer relationship of a scheduled tensor, for
  // example, grouping multiple reductions.
  void updateMaxProducerPosition();

 protected:
  void setDomain(TensorDomain* td) {
    domain_ = td;
  }

 private:
  int normalizeAxisPos(int pos) const {
    if (pos < 0) {
      pos += nDims();
    }
    return pos;
  }

  //! A helper function to maintain the consistency of schedules of
  //! multiple outputs when doing rfactor on multi-output reduction ops.
  TensorView* multiOutputRfactorHelper(
      TensorView* tv,
      const std::vector<int>& axes);

 private:
  TensorDomain* domain_ = nullptr;
  unsigned int compute_at_pos_ = 0;
  unsigned int max_producer_pos_ = 0;
  MemoryType memory_type_ = MemoryType::Local;
  SwizzleType swizzle_type_ = SwizzleType::NoSwizzle;
  std::vector<IterDomain*> axes_to_swizzle_;
  bool is_double_buffered_ = false;

  //! Indicates if the tensor is circular buffered.
  bool is_circular_buffered_ = false;

  //! Indicates the circular buffering stage depth if applicable.
  unsigned int circular_buffer_stage_ = 0;

  // Special handling for CPU based zero-dim tensors (i.e. CPU Tensors that
  // only have one value). This is only used on input values, otherwise it is
  // ignored. The special handling is important because these "scalars" should
  // be type promoted as a tensor, but we want to avoid explicitly copying the
  // data, so we want to pass the data value as a standard kernel argument
  // value.
  bool cpu_scalar_ = false;

  //! Indicates if this tensor view has a swizzle operator on its tensor
  //! domain. This is a temporary flag for indicating that the new swizzle
  //! implementation is used and will be removed in follow ups.
  bool has_swizzle_op_ = false;
};

//! A simple TensorView builder
//!
//! Example usage:
//!
//!   auto tv = TensorViewBuilder()
//!                 .ndims(ndims)
//!                 .dtype(dtype)
//!                 .contiguity(contiguity)
//!                 .build();
//!
class TORCH_CUDA_CU_API TensorViewBuilder {
 public:
  //! Set the number of dimensions of the tensor (default 0, meaning scalar)
  TensorViewBuilder& ndims(size_t ndims);

  //! Set the data type of the tensor (default DataType::Float)
  TensorViewBuilder& dtype(DataType dtype);

  //! Set the contiguity information (default non-contiguous)
  TensorViewBuilder& contiguity(std::vector<bool> contiguity);

  //! Set the shape (default 0 dimensional, i.e. scalar)
  TensorViewBuilder& shape(std::vector<Val*> shape);
  TensorViewBuilder& shape(const std::vector<int64_t>& shape);
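
  //! e.g. (a hedged sketch): a contiguous 2-D tensor with concrete extents:
  //!   auto tv = TensorViewBuilder()
  //!                 .shape({16, 32})
  //!                 .contiguity({true, true})
  //!                 .build();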

  //! Creates a new TensorView with the specified options
  TensorView* build() const;

 private:
  size_t ndims_ = 0;
  DataType dtype_ = DataType::Float;
  std::vector<bool> contiguity_;
  std::vector<Val*> shape_;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch