1 | #pragma once |
2 | |
3 | #include <c10/core/ScalarType.h> |
4 | #include <c10/util/Exception.h> |
5 | #include <c10/util/Optional.h> |
6 | |
7 | #include <c10/macros/Export.h> |
8 | |
9 | #include <array> |
10 | #include <cstdint> |
11 | #include <iostream> |
12 | #include <string> |
13 | #include <unordered_set> |
14 | |
15 | namespace torch { |
16 | namespace jit { |
17 | namespace fuser { |
18 | namespace cuda { |
19 | |
20 | // https://stackoverflow.com/questions/18837857/cant-use-enum-class-as-unordered-map-key |
21 | struct TypeHash { |
22 | template <typename T> |
23 | std::size_t operator()(T t) const { |
24 | return static_cast<std::size_t>(t); |
25 | } |
26 | }; |
27 | |
28 | // Order of strength |
29 | enum class ValType { |
30 | TensorDomain, |
31 | IterDomain, |
32 | TensorView, |
33 | Scalar, |
34 | NamedScalar, |
35 | Predicate, |
36 | TensorIndex, |
37 | IntPair |
38 | }; |
39 | |
40 | // Manual - The user provides the Bool value. Predicate generation is bypassed. |
41 | // Inline corresponds with PredicateCompute::getInlinePredicate |
42 | // Unswitch corresponds with UnswitchPredicate::get |
43 | // Misaligned - PredicateCompute::getInlinePredicate + Misaligned flag |
44 | // Shift - ShiftPredicateInserter::getShiftPredicate |
45 | // Padding - ShiftPredicateInserter::getPaddingPredicate |
46 | // ReductionWrite - Same as Inline but without reduction axes |
47 | enum class PredicateType { |
48 | Manual, |
49 | Inline, |
50 | Unswitch, |
51 | Vectorize, |
52 | Misaligned, |
53 | Shift, |
54 | Padding, |
55 | ReductionWrite |
56 | }; |
57 | |
// Index type is a convenience type that may be a 64- or 32-bit signed
// integer. This is helpful for math on indexing/size when we don't yet know
// what the index type will be. It lets us avoid assuming the Welford count
// must be int64_t, which is relatively heavy to carry around. Index will be
// resolved at compile time with KernelIndexMode.
63 | enum class DataType { |
64 | Double, |
65 | Float, |
66 | Half, |
67 | Int, |
68 | Index, |
69 | Int32, |
70 | Bool, |
71 | BFloat16, |
72 | ComplexFloat, |
73 | ComplexDouble, |
74 | // Vectorized types, used for reinterpret casting views |
75 | // TODO: add more vectorized types |
76 | Double_2, |
77 | Float_2, |
78 | // Null |
79 | Null |
80 | }; |
81 | |
82 | enum class KernelIndexMode { INT32, INT64 }; |
83 | |
84 | DataType indexModeToDtype(KernelIndexMode index_mode); |
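// A sketch of the correspondence indexModeToDtype presumably implements (the
// definition lives in the .cpp):
//   KernelIndexMode::INT32 -> DataType::Int32
//   KernelIndexMode::INT64 -> DataType::Int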
85 | |
86 | // Returns if the datatype is a floating point type |
87 | bool isFloatingPointType(DataType dtype); |
// Returns if the datatype is an integral type
bool isIntegralType(DataType dtype);
// Returns if the datatype is a boolean type
bool isBooleanType(DataType dtype);
92 | // Returns if the datatype is a complex type |
93 | bool isComplexType(DataType dtype); |
94 | // Returns if the datatype is a vector type |
95 | bool isVectorType(DataType dtype); |
96 | // Return the corresponding vector type |
97 | DataType getVectorType(DataType dtype, size_t vec_size); |
98 | // Return the vector size for the given vector type |
99 | int getVectorSizeFromType(DataType dtype); |
// Return the corresponding scalar type of a vector type
101 | DataType getTypeFromVectorType(DataType dtype); |
102 | // Return the corresponding scalar of a complex type |
103 | DataType getTypeFromComplexType(DataType dtype); |
104 | // Return if the datatype is supported on the current device |
105 | TORCH_CUDA_CU_API bool isSupportedTypeByDevice(DataType dtype); |
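// Illustrative behavior of the vector helpers above, assuming the natural
// mapping between scalar and vector types (the definitions are in the .cpp):
//   getVectorType(DataType::Double, 2)             -> DataType::Double_2
//   getVectorSizeFromType(DataType::Float_2)       -> 2
//   getTypeFromVectorType(DataType::Float_2)       -> DataType::Float
//   getTypeFromComplexType(DataType::ComplexDouble) -> DataType::Double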
106 | |
107 | enum class ExprType { |
108 | Invalid, |
109 | FullOp, |
110 | ARangeOp, |
111 | EyeOp, |
112 | UnaryOp, |
113 | BinaryOp, |
114 | TernaryOp, |
115 | RNGOp, |
116 | ReductionOp, |
117 | GroupedReductionOp, |
118 | BroadcastOp, |
119 | WelfordOp, |
120 | GroupedWelfordOp, |
121 | MmaOp, |
122 | TransposeOp, |
123 | ExpandOp, |
124 | ShiftOp, |
125 | GatherOp, |
126 | ViewOp, |
127 | LoadStoreOp, |
128 | Split, |
129 | ViewAsScalar, |
130 | Merge, |
131 | Swizzle2D, |
132 | Swizzle2DInt, |
133 | PairSelect, |
134 | Allocate, |
135 | BlockSync, |
136 | GridSync, |
137 | CpAsyncWait, |
138 | CpAsyncCommit, |
139 | InitMagicZero, |
140 | UpdateMagicZero, |
141 | ForLoop, |
142 | IfThenElse, |
143 | GridReduction, |
144 | GroupedGridReduction, |
145 | GridBroadcast, |
146 | GridWelford, |
147 | GroupedGridWelford, |
148 | AllocateFusedReduction |
149 | }; |
150 | |
151 | enum class UnaryOpType { |
152 | Abs, |
153 | Acos, |
154 | Address, |
155 | Asin, |
156 | Atan, |
157 | Atanh, |
158 | Cast, |
159 | Ceil, |
160 | Cos, |
161 | Cosh, |
162 | Exp, |
163 | Expm1, |
164 | Erf, |
165 | Erfc, |
166 | Floor, |
167 | Frac, |
168 | Gelu, |
169 | Imag, |
170 | Silu, |
171 | Lgamma, |
172 | Log, |
173 | Log10, |
174 | Log1p, |
175 | Log2, |
176 | BitCast, |
177 | Neg, |
178 | Real, |
179 | Reciprocal, |
180 | Relu, |
181 | Rsqrt, |
182 | Round, |
183 | Set, |
184 | Sigmoid, |
185 | Sin, |
186 | Sinh, |
187 | Sqrt, |
188 | Tan, |
189 | Tanh, |
190 | Trunc, |
191 | |
192 | // Tools to help debugging |
193 | Print, |
194 | |
195 | // Might be a bitwise operator or boolean operator. |
196 | Not, |
197 | |
198 | // Operators returning boolean values |
199 | IsFinite, |
200 | IsInf, |
201 | IsNan, |
202 | IsNegInf, |
203 | IsPosInf, |
204 | IsReal, |
205 | }; |
206 | |
// Primarily for Not, which could be a boolean not or a bitwise not.
208 | bool alsoBooleanOperator(const UnaryOpType uopt); |
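// Illustrative example (assuming the obvious behavior): since Not may be
// either a logical or a bitwise negation, alsoBooleanOperator(
// UnaryOpType::Not) would return true, while a purely arithmetic op such as
// UnaryOpType::Abs would return false.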
209 | |
// TODO: Order of this list is important as it affects type promotion. It's
// not in the right order now.
212 | enum class BinaryOpType { |
213 | // Math Ops |
214 | Add, |
215 | Atan2, |
216 | Div, |
217 | Fmod, |
218 | Max, |
219 | Min, |
220 | Mul, |
221 | Pow, |
222 | Remainder, |
223 | Sub, |
224 | // TypeAs, |
225 | |
// Integer output ops. If changing, modify isIntegerOp.
227 | Mod, |
228 | CeilDiv, |
229 | Lshift, |
230 | Rshift, |
231 | |
// Logical Ops
// Comparison operations; leave the position of Eq as the first logical op,
// see isLogicalOp(BinaryOpType bopt)
235 | Eq, |
236 | GE, |
237 | GT, |
238 | LE, |
239 | LT, |
240 | NE, |
241 | |
// Maybe bitwise or boolean ops; leave the position of And as the first
// bool/int op. These are ops that have different operators based on output
// type; see alsoBooleanOperator. These ops also don't work on floating point
// inputs.
245 | And, |
246 | Or, |
247 | Xor |
248 | }; |
249 | |
250 | enum class RNGOpType { |
251 | Uniform, // Uniform in [0, 1) |
252 | UniformRange, // Uniform in [low, high] |
253 | }; |
254 | |
// Return if output of operator should be an integer
bool isIntegerOp(const BinaryOpType bopt);
257 | |
258 | // Return if output of operator should be a boolean |
259 | bool isLogicalOp(const BinaryOpType bopt); |
260 | |
// Operations that could be a bitwise operation or a boolean operation
// depending on input; for example, bitwise_and is also used for boolean and
// in the jit
263 | bool alsoBooleanOperator(const BinaryOpType bopt); |
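// A sketch of how the positional layout of BinaryOpType enables these
// checks, assuming simple range comparisons (the actual definitions are in
// the .cpp):
//   isIntegerOp(bopt): bopt >= BinaryOpType::Mod && bopt <= BinaryOpType::Rshift
//   isLogicalOp(bopt): bopt >= BinaryOpType::Eq  && bopt <= BinaryOpType::NE
//   alsoBooleanOperator(bopt): bopt >= BinaryOpType::And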
264 | |
265 | enum class TernaryOpType { Clamp, Lerp, Threshold, Where }; |
266 | |
267 | enum class ParallelType { |
268 | BIDz, |
269 | BIDy, |
270 | BIDx, |
271 | TIDz, |
272 | TIDy, |
273 | TIDx, |
274 | Vectorize, |
275 | MisalignedVectorize, |
276 | Unroll, |
277 | Unswitch, |
278 | Mma, |
279 | Group, |
280 | Serial |
281 | }; |
282 | |
283 | TORCH_CUDA_CU_API std::unordered_set<ParallelType> allParallelTypesExcept( |
284 | const std::unordered_set<ParallelType>& except); |
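// Illustrative call (a usage sketch): gather every parallel type other than
// the serial one, e.g. to iterate over all thread/block bindings:
//   auto types = allParallelTypesExcept({ParallelType::Serial});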
285 | |
286 | static constexpr std::array<ParallelType, 6> kParallelTypeThreads = { |
287 | ParallelType::BIDx, |
288 | ParallelType::BIDy, |
289 | ParallelType::BIDz, |
290 | ParallelType::TIDx, |
291 | ParallelType::TIDy, |
292 | ParallelType::TIDz}; |
293 | |
294 | static constexpr std::array<ParallelType, 3> kParallelTypeBIDs = { |
295 | ParallelType::BIDx, |
296 | ParallelType::BIDy, |
297 | ParallelType::BIDz}; |
298 | |
299 | static constexpr std::array<ParallelType, 3> kParallelTypeTIDs = { |
300 | ParallelType::TIDx, |
301 | ParallelType::TIDy, |
302 | ParallelType::TIDz}; |
303 | |
304 | enum class MemoryType { Local, Shared, Global }; |
305 | |
// Sometimes broadcasted tensors may be passed into the kernel with an
// explicit size-1 dimension. If that size is there, we need to account for
// the stride that comes with it, even if the stride is 0. If we didn't
// account for that stride when accessing a tensor like [b2{1}, i0, i1], we
// would linearize the access as [i0*stride[0] + i1*stride[1]] when it should
// be [i0*stride[1] + i1*stride[2]]. Broadcasts that translate to a physical
// memory dim we consider "with stride"; broadcasts that exist only through
// our broadcast op we consider "without stride".
314 | enum class IterType { |
315 | Iteration, |
316 | Reduction, |
317 | Broadcast, |
318 | Gather, |
319 | Stride, |
320 | VectorComponent |
321 | }; |
322 | |
323 | enum class SwizzleType { NoSwizzle, Transpose }; |
324 | |
325 | // Used for Iteration Domain mapping modes in ComputeAtMap |
326 | enum class IdMappingMode { PERMISSIVE, EXACT, LOOP }; |
327 | |
328 | static constexpr std::array<IdMappingMode, 3> kIdMappingModes = { |
329 | IdMappingMode::PERMISSIVE, |
330 | IdMappingMode::EXACT, |
331 | IdMappingMode::LOOP}; |
332 | |
333 | // Used to annotate the special memory intrinsics that a loadstore |
334 | // op will be lowered to. |
335 | enum class LoadStoreOpType { LdMatrix, LdMatrixTranspose, CpAsync }; |
336 | |
337 | // Used to label what part of the double buffered iterdomain |
338 | // a for loop is materializing. |
339 | enum class DoubleBufferLoopStage { NotApplicable, Prolog, Main, Epilog }; |
340 | |
//! Supported swizzle types; these correspond to the swizzle functions on the
//! CUDA runtime side. Named swizzle_2d to reserve the option of adding a
//! swizzle_1d.
//!
//! TODO: unify with existing swizzle logic, which currently
//! doesn't have the same type.
347 | enum class Swizzle2DType { NoSwizzle = 0, ZShape, Transpose, XOR, Scatter }; |
348 | |
349 | //! Modes of swizzle, see [Note on swizzle mode]. |
350 | enum class SwizzleMode { NoSwizzle = 0, Data, Loop }; |
351 | |
// Returns if the function needs an f suffix on the operator when operating
// on a float value, e.g. sin -> sinf
354 | bool needFloatSuffix(UnaryOpType t); |
355 | bool needFloatSuffix(BinaryOpType t); |
356 | bool needFloatSuffix(RNGOpType t); |
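// Illustrative example (assuming standard CUDA math naming): needFloatSuffix(
// UnaryOpType::Sin) would return true, so codegen emits sinf for float
// inputs and sin for double inputs.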
357 | |
358 | ValType promote_type(const ValType& t1, const ValType& t2); |
359 | DataType promote_type(const DataType& t1, const DataType& t2); |
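// Illustrative behavior (a sketch; the definitions are in the .cpp): since
// ValType is listed in order of strength, promote_type(ValType::TensorView,
// ValType::Scalar) would yield ValType::TensorView, and DataType promotion
// follows the usual numeric rules, e.g. promote_type(DataType::Float,
// DataType::Double) -> DataType::Double.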
360 | |
// If the type cannot be found (i.e. codegen does not support the provided
// type), returns DataType::Null
363 | TORCH_CUDA_CU_API DataType aten_to_data_type(const at::ScalarType& scalar_type); |
364 | TORCH_CUDA_CU_API at::ScalarType data_type_to_aten(const DataType& data_type); |
365 | |
366 | TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const ValType); |
367 | TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const PredicateType); |
368 | TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const DataType); |
369 | TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const ExprType); |
370 | TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const UnaryOpType); |
371 | TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const BinaryOpType); |
372 | TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const TernaryOpType); |
373 | TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const RNGOpType); |
374 | TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const ParallelType); |
375 | TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const MemoryType); |
376 | TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const IterType); |
377 | TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const IdMappingMode); |
378 | TORCH_CUDA_CU_API std::ostream& operator<<( |
379 | std::ostream&, |
380 | const LoadStoreOpType); |
381 | TORCH_CUDA_CU_API std::ostream& operator<<( |
382 | std::ostream&, |
383 | const DoubleBufferLoopStage); |
384 | TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const Swizzle2DType&); |
385 | TORCH_CUDA_CU_API std::ostream& operator<<(std::ostream&, const SwizzleMode&); |
386 | |
387 | std::string stringifyBooleanOp(const UnaryOpType); |
388 | std::string stringifyBooleanOp(const BinaryOpType); |
389 | |
390 | std::string stringifyThreadSize(const ParallelType); |
391 | std::string stringifyThread(const ParallelType); |
392 | std::string typePrefix(const DataType); |
393 | |
394 | // TODO: ThreadDim should be BlockDim and BlockDim should be GridDim |
395 | // Returns if parallel type is TID[x, y, z] |
396 | TORCH_CUDA_CU_API bool isParallelTypeThreadDim(ParallelType); |
397 | // Returns if parallel type is BID[x, y, z] |
398 | TORCH_CUDA_CU_API bool isParallelTypeBlockDim(ParallelType); |
399 | // Returns if parallel type is a grid or block parallelization dimension |
400 | TORCH_CUDA_CU_API bool isParallelTypeThread(ParallelType); |
401 | |
402 | TORCH_CUDA_CU_API bool isParallelTypeVectorize(ParallelType); |
403 | |
404 | TORCH_CUDA_CU_API c10::optional<std::string> inline_op_str(const UnaryOpType); |
405 | TORCH_CUDA_CU_API c10::optional<std::string> inline_op_str(const BinaryOpType); |
406 | TORCH_CUDA_CU_API c10::optional<std::string> inline_op_str(const RNGOpType); |
407 | TORCH_CUDA_CU_API c10::optional<std::string> integer_op_str(const BinaryOpType); |
408 | TORCH_CUDA_CU_API c10::optional<std::string> bool_op_str(const BinaryOpType); |
409 | |
410 | TORCH_CUDA_CU_API c10::optional<std::string> cast_func_str( |
411 | const std::pair<DataType, DataType>&); |
412 | |
413 | TORCH_CUDA_CU_API size_t dataTypeSize(DataType type); |
414 | |
// If the index type is known, it will automatically be used here
416 | TORCH_CUDA_CU_API size_t dataTypeSize(DataType type, DataType index_type); |
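// Illustrative sizes (assuming the usual widths): dataTypeSize(
// DataType::Float) == 4 and dataTypeSize(DataType::Double_2) == 16, while
// dataTypeSize(DataType::Index, DataType::Int32) resolves Index to a 4-byte
// type as described above.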
417 | |
418 | enum class LaunchConfigType { |
419 | Compatible, |
420 | SharedMemory, |
421 | BIDz, |
422 | BIDy, |
423 | BIDx, |
424 | TIDz, |
425 | TIDy, |
426 | TIDx |
427 | }; |
428 | |
const char* const kMagicZeroName = "nvfuser_zero";
430 | |
431 | //! Maximum number of reductions that can be grouped together. The |
//! limit can be increased by extending struct Tuple defined in tuple.cu.
433 | static constexpr int kMaxNumGroupedReductions = 8; |
434 | |
435 | } // namespace cuda |
436 | } // namespace fuser |
437 | } // namespace jit |
438 | } // namespace torch |
439 | |