#pragma once

#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>

#include <c10/macros/Export.h>

#include <array>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_set>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// https://stackoverflow.com/questions/18837857/cant-use-enum-class-as-unordered-map-key
struct TypeHash {
  template <typename T>
  std::size_t operator()(T t) const {
    return static_cast<std::size_t>(t);
  }
};
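// TypeHash lets the enum classes below be used as keys of the standard hash
// containers. A minimal usage sketch (illustrative only):
//
//   std::unordered_set<ParallelType, TypeHash> visited;
//   std::unordered_map<DataType, int, TypeHash> counts;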

// Order of strength
enum class ValType {
  TensorDomain,
  IterDomain,
  TensorView,
  Scalar,
  NamedScalar,
  Predicate,
  TensorIndex,
  IntPair
};

// Manual - The user provides the Bool value. Predicate generation is bypassed.
// Inline corresponds with PredicateCompute::getInlinePredicate
// Unswitch corresponds with UnswitchPredicate::get
// Misaligned - PredicateCompute::getInlinePredicate + Misaligned flag
// Shift - ShiftPredicateInserter::getShiftPredicate
// Padding - ShiftPredicateInserter::getPaddingPredicate
// ReductionWrite - Same as Inline but without reduction axes
enum class PredicateType {
  Manual,
  Inline,
  Unswitch,
  Vectorize,
  Misaligned,
  Shift,
  Padding,
  ReductionWrite
};

// Index type is a convenience type that may be a 64- or 32-bit signed integer.
// This is helpful for math on indexing/size when we don't know what the index
// type will be. It also lets us avoid assuming that, for example, the Welford
// count must be int64_t, which is relatively heavy to carry around. Index will
// be resolved at compile time with KernelIndexMode.
enum class DataType {
  Double,
  Float,
  Half,
  Int,
  Index,
  Int32,
  Bool,
  BFloat16,
  ComplexFloat,
  ComplexDouble,
  // Vectorized types, used for reinterpret casting views
  // TODO: add more vectorized types
  Double_2,
  Float_2,
  // Null
  Null
};

enum class KernelIndexMode { INT32, INT64 };

DataType indexModeToDtype(KernelIndexMode index_mode);
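// A minimal sketch of the mapping this declaration implies (illustrative; the
// actual definition lives in the corresponding .cpp file):
//
//   DataType indexModeToDtype(KernelIndexMode index_mode) {
//     switch (index_mode) {
//       case KernelIndexMode::INT32:
//         return DataType::Int32;
//       case KernelIndexMode::INT64:
//         return DataType::Int; // Int is the 64-bit integer type here
//       default:
//         TORCH_INTERNAL_ASSERT(false, "Invalid index mode");
//     }
//   }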

// Returns if the datatype is a floating point type
bool isFloatingPointType(DataType dtype);
// Returns if the datatype is an integer type
bool isIntegralType(DataType dtype);
// Returns if the datatype is a boolean type
bool isBooleanType(DataType dtype);
// Returns if the datatype is a complex type
bool isComplexType(DataType dtype);
// Returns if the datatype is a vector type
bool isVectorType(DataType dtype);
// Return the corresponding vector type
DataType getVectorType(DataType dtype, size_t vec_size);
// Return the vector size for the given vector type
int getVectorSizeFromType(DataType dtype);
// Return the corresponding scalar type of a vector type
DataType getTypeFromVectorType(DataType dtype);
// Return the corresponding scalar type of a complex type
DataType getTypeFromComplexType(DataType dtype);
// Return if the datatype is supported on the current device
TORCH_CUDA_CU_API bool isSupportedTypeByDevice(DataType dtype);
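// Illustrative expectations for the vector-type helpers above, given the
// vectorized DataType entries currently defined (a sketch, not exhaustive):
//
//   getVectorType(DataType::Double, 2)        -> DataType::Double_2
//   getTypeFromVectorType(DataType::Float_2)  -> DataType::Float
//   getVectorSizeFromType(DataType::Double_2) -> 2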

enum class ExprType {
  Invalid,
  FullOp,
  ARangeOp,
  EyeOp,
  UnaryOp,
  BinaryOp,
  TernaryOp,
  RNGOp,
  ReductionOp,
  GroupedReductionOp,
  BroadcastOp,
  WelfordOp,
  GroupedWelfordOp,
  MmaOp,
  TransposeOp,
  ExpandOp,
  ShiftOp,
  GatherOp,
  ViewOp,
  LoadStoreOp,
  Split,
  ViewAsScalar,
  Merge,
  Swizzle2D,
  Swizzle2DInt,
  PairSelect,
  Allocate,
  BlockSync,
  GridSync,
  CpAsyncWait,
  CpAsyncCommit,
  InitMagicZero,
  UpdateMagicZero,
  ForLoop,
  IfThenElse,
  GridReduction,
  GroupedGridReduction,
  GridBroadcast,
  GridWelford,
  GroupedGridWelford,
  AllocateFusedReduction
};

enum class UnaryOpType {
  Abs,
  Acos,
  Address,
  Asin,
  Atan,
  Atanh,
  Cast,
  Ceil,
  Cos,
  Cosh,
  Exp,
  Expm1,
  Erf,
  Erfc,
  Floor,
  Frac,
  Gelu,
  Imag,
  Silu,
  Lgamma,
  Log,
  Log10,
  Log1p,
  Log2,
  BitCast,
  Neg,
  Real,
  Reciprocal,
  Relu,
  Rsqrt,
  Round,
  Set,
  Sigmoid,
  Sin,
  Sinh,
  Sqrt,
  Tan,
  Tanh,
  Trunc,

  // Tools to help debugging
  Print,

  // Might be a bitwise operator or boolean operator.
  Not,

  // Operators returning boolean values
  IsFinite,
  IsInf,
  IsNan,
  IsNegInf,
  IsPosInf,
  IsReal,
};

// Primarily for Not, which could be a logical (boolean) not or a bitwise not.
bool alsoBooleanOperator(const UnaryOpType uopt);
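// For example, Not is expected to lower to the bitwise operator "~" for
// integer inputs, but to the logical form "!" when the inputs are
// DataType::Bool (an illustrative description of why this query exists; the
// actual lowering is handled by the code generator).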

// TODO: Order of this list is important as it affects type promotion. It's not
// in the right order now.
enum class BinaryOpType {
  // Math Ops
  Add,
  Atan2,
  Div,
  Fmod,
  Max,
  Min,
  Mul,
  Pow,
  Remainder,
  Sub,
  // TypeAs,

  // Integer output ops. If changing, modify isIntegerOp.
  Mod,
  CeilDiv,
  Lshift,
  Rshift,

  // Logical Ops
  // Int operations; leave the position of Eq as the first logical op, see
  // isLogicalOp(BinaryOpType bopt)
  Eq,
  GE,
  GT,
  LE,
  LT,
  NE,

  // Maybe bitwise or boolean op; leave the position of And as the first
  // bool/int op. These are ops that use different operators depending on the
  // output type (see alsoBooleanOperator). They also don't work on floating
  // point inputs.
  And,
  Or,
  Xor
};

enum class RNGOpType {
  Uniform, // Uniform in [0, 1)
  UniformRange, // Uniform in [low, high)
};

// Return if output of the operator should be an integer
bool isIntegerOp(const BinaryOpType bopt);

// Return if output of the operator should be a boolean
bool isLogicalOp(const BinaryOpType bopt);

// Operations that could be a bitwise operation or a boolean operation
// depending on input, for example bitwise_and is also used for boolean and in
// the jit
bool alsoBooleanOperator(const BinaryOpType bopt);
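// The enum-position comments above matter because these classification helpers
// are naturally implemented as range checks over BinaryOpType. A minimal
// sketch of what such checks could look like (illustrative only; the real
// definitions live in the corresponding .cpp file):
//
//   bool isIntegerOp(const BinaryOpType bopt) {
//     return bopt >= BinaryOpType::Mod && bopt <= BinaryOpType::Rshift;
//   }
//
//   bool isLogicalOp(const BinaryOpType bopt) {
//     return bopt >= BinaryOpType::Eq && bopt <= BinaryOpType::NE;
//   }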

enum class TernaryOpType { Clamp, Lerp, Threshold, Where };

enum class ParallelType {
  BIDz,
  BIDy,
  BIDx,
  TIDz,
  TIDy,
  TIDx,
  Vectorize,
  MisalignedVectorize,
  Unroll,
  Unswitch,
  Mma,
  Group,
  Serial
};

TORCH_CUDA_CU_API std::unordered_set<ParallelType> allParallelTypesExcept(
    const std::unordered_set<ParallelType>& except);
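// A usage sketch (illustrative): build the set of all parallel types other
// than the vectorization-related ones.
//
//   const auto types = allParallelTypesExcept(
//       {ParallelType::Vectorize, ParallelType::MisalignedVectorize});
//   // types is then expected to contain the BID/TID dimensions as well as
//   // Unroll, Unswitch, Mma, Group, and Serial.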

static constexpr std::array<ParallelType, 6> kParallelTypeThreads = {
    ParallelType::BIDx,
    ParallelType::BIDy,
    ParallelType::BIDz,
    ParallelType::TIDx,
    ParallelType::TIDy,
    ParallelType::TIDz};

static constexpr std::array<ParallelType, 3> kParallelTypeBIDs = {
    ParallelType::BIDx,
    ParallelType::BIDy,
    ParallelType::BIDz};

static constexpr std::array<ParallelType, 3> kParallelTypeTIDs = {
    ParallelType::TIDx,
    ParallelType::TIDy,
    ParallelType::TIDz};

enum class MemoryType { Local, Shared, Global };
// Sometimes broadcasted tensors may be inputs in the kernel with an explicit 1
// size. If that size is there, we need to account for the fact that there is
// also a stride there, even if the stride = 0. If we don't account for that
// stride when accessing a tensor like [b2{1}, i0, i1], we would linearize the
// access like [i0*stride[0] + i1*stride[1]] when it should be [i0*stride[1] +
// i1*stride[2]]. Broadcasts that translate to a physical memory dim we
// consider "with stride"; broadcasts present only through our broadcast op we
// consider "without stride".
enum class IterType {
  Iteration,
  Reduction,
  Broadcast,
  Gather,
  Stride,
  VectorComponent
};

enum class SwizzleType { NoSwizzle, Transpose };

// Used for Iteration Domain mapping modes in ComputeAtMap
enum class IdMappingMode { PERMISSIVE, EXACT, LOOP };

static constexpr std::array<IdMappingMode, 3> kIdMappingModes = {
    IdMappingMode::PERMISSIVE,
    IdMappingMode::EXACT,
    IdMappingMode::LOOP};

// Used to annotate the special memory intrinsics that a loadstore
// op will be lowered to.
enum class LoadStoreOpType { LdMatrix, LdMatrixTranspose, CpAsync };

// Used to label what part of the double buffered iterdomain
// a for loop is materializing.
enum class DoubleBufferLoopStage { NotApplicable, Prolog, Main, Epilog };

//! Supported swizzle types,
//! corresponds to the swizzle functions in the cuda runtime files.
//! Named swizzle_2d to reserve the option of adding a swizzle_1d later.
//!
//! TODO: unify with existing swizzle logic, currently
//! doesn't have the same type.
enum class Swizzle2DType { NoSwizzle = 0, ZShape, Transpose, XOR, Scatter };

//! Modes of swizzle, see [Note on swizzle mode].
enum class SwizzleMode { NoSwizzle = 0, Data, Loop };

// Returns if the function needs an f suffix on the operator when operating on
// a float value, e.g. sin -> sinf
bool needFloatSuffix(UnaryOpType t);
bool needFloatSuffix(BinaryOpType t);
bool needFloatSuffix(RNGOpType t);
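// For example, needFloatSuffix(UnaryOpType::Sin) is expected to return true so
// that codegen emits sinf rather than sin for float operands (illustrative;
// the actual classification lives in the corresponding .cpp file).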

ValType promote_type(const ValType& t1, const ValType& t2);
DataType promote_type(const DataType& t1, const DataType& t2);

// If the type cannot be found (i.e. codegen does not support the provided
// type), returns DataType::Null
TORCH_CUDA_CU_API DataType aten_to_data_type(const at::ScalarType& scalar_type);
TORCH_CUDA_CU_API at::ScalarType data_type_to_aten(const DataType& data_type);
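// A round-trip sketch (illustrative):
//
//   auto dt = aten_to_data_type(at::ScalarType::Float); // DataType::Float
//   TORCH_INTERNAL_ASSERT(dt != DataType::Null);
//   auto st = data_type_to_aten(dt); // back to at::ScalarType::Float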

TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const ValType);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const PredicateType);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const DataType);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const ExprType);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const UnaryOpType);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const BinaryOpType);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const TernaryOpType);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const RNGOpType);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const ParallelType);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const MemoryType);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const IterType);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const IdMappingMode);
TORCH_CUDA_CU_API std::ostream& operator<<(
    std::ostream&,
    const LoadStoreOpType);
TORCH_CUDA_CU_API std::ostream& operator<<(
    std::ostream&,
    const DoubleBufferLoopStage);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const Swizzle2DType&);
TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const SwizzleMode&);

std::string stringifyBooleanOp(const UnaryOpType);
std::string stringifyBooleanOp(const BinaryOpType);

std::string stringifyThreadSize(const ParallelType);
std::string stringifyThread(const ParallelType);
std::string typePrefix(const DataType);

// TODO: ThreadDim should be BlockDim and BlockDim should be GridDim
// Returns if parallel type is TID[x, y, z]
TORCH_CUDA_CU_API bool isParallelTypeThreadDim(ParallelType);
// Returns if parallel type is BID[x, y, z]
TORCH_CUDA_CU_API bool isParallelTypeBlockDim(ParallelType);
// Returns if parallel type is a grid or block parallelization dimension
TORCH_CUDA_CU_API bool isParallelTypeThread(ParallelType);

TORCH_CUDA_CU_API bool isParallelTypeVectorize(ParallelType);

TORCH_CUDA_CU_API c10::optional<std::string> inline_op_str(const UnaryOpType);
TORCH_CUDA_CU_API c10::optional<std::string> inline_op_str(const BinaryOpType);
TORCH_CUDA_CU_API c10::optional<std::string> inline_op_str(const RNGOpType);
TORCH_CUDA_CU_API c10::optional<std::string> integer_op_str(const BinaryOpType);
TORCH_CUDA_CU_API c10::optional<std::string> bool_op_str(const BinaryOpType);

TORCH_CUDA_CU_API c10::optional<std::string> cast_func_str(
    const std::pair<DataType, DataType>&);

TORCH_CUDA_CU_API size_t dataTypeSize(DataType type);

// If the index type is known, it will automatically be used here to resolve
// the size of DataType::Index.
TORCH_CUDA_CU_API size_t dataTypeSize(DataType type, DataType index_type);
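// Illustrative expectations (a sketch; the exact behavior is defined in the
// corresponding .cpp file): dataTypeSize(DataType::Float) would be 4, while
// DataType::Index has no fixed size until it is resolved, e.g.
//
//   dataTypeSize(DataType::Index, DataType::Int32); // expected to be 4
//   dataTypeSize(DataType::Index, DataType::Int);   // expected to be 8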

enum class LaunchConfigType {
  Compatible,
  SharedMemory,
  BIDz,
  BIDy,
  BIDx,
  TIDz,
  TIDy,
  TIDx
};

const char* const kMagicZeroName = "nvfuser_zero";

//! Maximum number of reductions that can be grouped together. The
//! limit can be increased by extending struct Tuple defined in tuple.cu.
static constexpr int kMaxNumGroupedReductions = 8;
435} // namespace cuda
436} // namespace fuser
437} // namespace jit
438} // namespace torch
439