#pragma once

#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/codegen/cuda/type.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

void debugPrint(const c10::TensorTypePtr& type);

bool is_zero_dim_tensor(const std::shared_ptr<c10::TensorType>& tensor_type);
bool is_zero_sized_tensor(const std::shared_ptr<c10::TensorType>& tensor_type);

bool is_cpu_scalar(const at::Tensor& tensor);
bool is_cpu_scalar(const c10::TensorType& tensor_type);

// TODO: merge these two
// check if input is compatible with 32b index mode
int getCommonDeviceCUDA(const at::ArrayRef<IValue>& inputs);
KernelIndexMode collectIndexMode(const at::ArrayRef<at::IValue>& inputs);

//! Types of debug print-outs
//!
//! These can be set through the `PYTORCH_NVFUSER_DUMP` environment variable
//!
enum class DebugDumpOption {
  FusionIr, //!< Dump the Fusion IR before lowering
  FusionIrMath, //!< Dump just the compute (math) part of the Fusion IR
  FusionIrPresched, //!< Dump the Fusion IR before it is scheduled
  KernelIr, //!< Dump the compiler Kernel IR
  ComputeAtMap, //!< Dump the computeAt map
  CudaKernel, //!< Dump the generated CUDA C++ kernel code
  CudaFull, //!< Dump the complete CUDA C++ code
  CudaToFile, //!< Dump CUDA strings to file
  DebugInfo, //!< Embed line info and debug info into the compiled kernel, and
             //!< dump the full CUDA C++ code
  LaunchParam, //!< Dump the launch parameters of the kernel
  FusionSegments, //!< Dump Segmented Fusion Graph
  FusionSegmenterLog, //!< Dump Detailed Segmenter Logging
  FusionArgs, //!< Print the runtime fusion arguments
  KernelArgs, //!< Print the runtime kernel arguments when launching kernels
  EffectiveBandwidth, //! Measure kernel performance and print effective
                      //! bandwidth
  FusionSegmentsDrawing, //!< Dump the Segmented Fusion Graph as a drawing
  PrintPtxasLog, //!< Print the ptxas verbose log including register usage
  BufferReuseInfo, //!< Dump the analysis details of local/shared buffer re-use
  SchedulerDebug, //! Dump scheduler heuristic parameters
  ParallelDimensions, //!< Dump known parallel dimensions
  Halo, //! Halo information of tensors
  PerfDebugVerbose, //! When running kernels, print verbose information
                    //! associated with what's running
  PythonDefinition, //! Python Frontend Fusion Definition.
  PythonFrontendDebug, //! Python Frontend debug information.
  TransformPropagator, //! When running TransformPropagator, print propagation
                       //! path and replay result
  Cubin, //! Dump compiled CUBIN
  Ptx, //! Dump compiled PTX
  BankConflictInfo, //! Dump bank conflict info
  SyncMap //! RAW dependency info
};

TORCH_CUDA_CU_API bool isDebugDumpEnabled(DebugDumpOption option);
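
// A usage sketch (illustrative, not part of the original header): code that
// produces one of the dumps above guards it behind isDebugDumpEnabled(), and
// the user opts in through the `PYTORCH_NVFUSER_DUMP` environment variable.
// The exact string accepted for each option is defined in the implementation,
// not here; `generated_code` below is a hypothetical variable.
//
//   if (isDebugDumpEnabled(DebugDumpOption::CudaKernel)) {
//     std::cout << generated_code << std::endl;
//   }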

//! Types of features to disable
//!
//! These can be set through the `PYTORCH_NVFUSER_DISABLE` environment variable
//!
enum class DisableOption {
  ArchCheck, //! Disable hardware-specific checks to enable cross-arch
             //! debugging
  CompileToSass, //! Disable direct compilation to SASS so the PTX can be
                 //! examined
  Fallback, //! Disable fallback
  Fma, //! Disable FMA instructions
  IndexHoist, //! Disable index hoisting
  Nvtx, //! Disable NVTX instrumentation
  PredicateElimination //! Disable predicate elimination
};

TORCH_CUDA_CU_API bool isOptionDisabled(DisableOption option);

//! Types of features to enable
//!
//! These can be set through the `PYTORCH_NVFUSER_ENABLE` environment variable
//!
enum class EnableOption {
  Complex, //! Enable complex support in Python
  KernelProfile, //! Enable intra-kernel performance profiling
  LinearDecomposition, //! Enable linear-bias decomposition
  ConvDecomposition, //! Enable conv-bias decomposition
};

TORCH_CUDA_CU_API bool isOptionEnabled(EnableOption option);
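
// A usage sketch (illustrative, not part of the original header): optional
// features are guarded at runtime by these queries, with the user driving
// them through the `PYTORCH_NVFUSER_ENABLE` / `PYTORCH_NVFUSER_DISABLE`
// environment variables. The accepted option strings live in the
// implementation; the guarded bodies below are hypothetical.
//
//   if (isOptionEnabled(EnableOption::KernelProfile)) {
//     // instrument the kernel for intra-kernel profiling
//   }
//   if (!isOptionDisabled(DisableOption::Nvtx)) {
//     // emit NVTX ranges around kernel launches
//   }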

// Check if the fallback path should be used, which dispatches to eager mode
// if any errors are encountered. Helpful for debugging.
bool useFallback();

//! Ceil integer division
constexpr int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}
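
// A worked example (not part of the original header). As written, the
// formula only rounds up for a >= 0 and b > 0:
//
//   ceilDiv(10, 3) == 4   // (10 + 3 - 1) / 3 == 12 / 3
//   ceilDiv(9, 3) == 3    // (9 + 3 - 1) / 3 == 11 / 3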

//! Simple mixin for suppressing copy & move operations, ex:
//!
//!   class Foo : public NonCopyable {
//!     ...
//!   };
//!
class NonCopyable {
 public:
  NonCopyable() = default;

  // No copy/move semantics
  NonCopyable(const NonCopyable&) = delete;
  NonCopyable& operator=(const NonCopyable&) = delete;
};

//! A generic root for a hierarchy of polymorphic classes:
//! - It ensures virtual destructors
//! - Provides the base->as<Derived>() and node->isA<T>() notation
class PolymorphicBase {
 public:
  virtual ~PolymorphicBase() = default;

  // Replacement for static_cast<T*>(ptr): ptr->as<T>()
  // (checked in DEBUG builds)
  template <class T>
  T* as() {
#ifdef NDEBUG
    auto downcast_ptr = static_cast<T*>(this);
#else
    auto downcast_ptr = dynamic_cast<T*>(this);
    TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
#endif
    return downcast_ptr;
  }

  template <class T>
  const T* as() const {
#ifdef NDEBUG
    auto downcast_ptr = static_cast<const T*>(this);
#else
    auto downcast_ptr = dynamic_cast<const T*>(this);
    TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
#endif
    return downcast_ptr;
  }

  //! Check if the runtime type is T (or derived from T)
  //!
  //! \note Don't use this for conditional casts. Instead, use:
  //!
  //!   if (auto t = dynamic_cast<T>(p)) { ... }
  //!
  //! instead of:
  //!
  //!   if (p->isA<T>()) { auto t = p->as<T>(); ... }
  //!
  template <class T>
  bool isA() const {
    return dynamic_cast<const T*>(this) != nullptr;
  }
};
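
// A usage sketch for PolymorphicBase (Derived and node are hypothetical; node
// is a PolymorphicBase* here, and neither is declared in this header):
//
//   class Derived : public PolymorphicBase { /* ... */ };
//
//   // Conditional cast: prefer dynamic_cast, per the note on isA() above.
//   if (auto derived = dynamic_cast<Derived*>(node)) { /* use derived */ }
//
//   // Unconditional cast when the type is already known: as<>() compiles to
//   // a static_cast in release builds and a checked dynamic_cast otherwise.
//   auto derived = node->as<Derived>();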

template <class T, std::enable_if_t<std::is_enum<T>::value, bool> = true>
constexpr unsigned int switch_pair(T t1, T t2) {
  constexpr unsigned int _WORD_SHIFT = 16;
  return ((unsigned int)t1 << _WORD_SHIFT) + (unsigned int)t2;
}
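
// A usage sketch for switch_pair (Color and handle are hypothetical, not part
// of this header): packing two enum values into one unsigned int lets a single
// switch dispatch on the pair, and because switch_pair is constexpr it can
// form the case labels as well as the condition.
//
//   enum class Color { Red, Green, Blue };
//
//   void handle(Color a, Color b) {
//     switch (switch_pair(a, b)) {
//       case switch_pair(Color::Red, Color::Green):
//         /* ... */
//         break;
//       default:
//         break;
//     }
//   }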

std::vector<int64_t> getTensorSizes(TensorTypePtr const& tensor_type);

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch