#pragma once

#include <ATen/CPUGeneratorImpl.h>
#include <ATen/LinalgBackend.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/DeprecatedTypeProperties.h>
#include <ATen/core/Generator.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/detail/HIPHooksInterface.h>
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/detail/ORTHooksInterface.h>
#include <c10/core/QEngine.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>

#include <cstdint>
#include <memory>
#include <mutex>

namespace at {

class Tensor;

enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };

class TORCH_API Context {
 public:
  Context();

  const Generator& defaultGenerator(Device device) {
    DeviceType device_type = device.type();
    initCUDAIfNeeded(device_type);
    initHIPIfNeeded(device_type);
    if (device_type == at::kCPU) {
      return at::detail::getDefaultCPUGenerator();
    } else if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
    } else if (device_type == at::kMPS) {
      return at::detail::getMPSHooks().getDefaultMPSGenerator();
    } else {
      AT_ERROR(DeviceTypeName(device_type), " device type not enabled.");
    }
  }
  Device getDeviceFromPtr(void* data, DeviceType device_type) {
    initCUDAIfNeeded(device_type);
    initHIPIfNeeded(device_type);
    if (device_type == at::kCPU) {
      return DeviceType::CPU;
    } else if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks().getDeviceFromPtr(data);
    } else {
      AT_ERROR(DeviceTypeName(device_type), " device type not enabled.");
    }
  }
  static bool isPinnedPtr(void* data) {
    return detail::getCUDAHooks().isPinnedPtr(data);
  }
  static bool hasOpenMP();
  static bool hasMKL();
  static bool hasLAPACK();
  static bool hasMKLDNN();
  static bool hasMAGMA() {
    return detail::getCUDAHooks().hasMAGMA();
  }
  static bool hasCUDA() {
    return detail::getCUDAHooks().hasCUDA();
  }
  static bool hasCUDART() {
    return detail::getCUDAHooks().hasCUDART();
  }
  static long versionCUDART() {
    return detail::getCUDAHooks().versionCUDART();
  }
  static bool hasCuDNN() {
    return detail::getCUDAHooks().hasCuDNN();
  }
  static long versionCuDNN() {
    return detail::getCUDAHooks().versionCuDNN();
  }
  static bool hasCuSOLVER() {
    return detail::getCUDAHooks().hasCuSOLVER();
  }
  static bool hasHIP() {
    return detail::getHIPHooks().hasHIP();
  }
  static bool hasMPS() {
    return detail::getMPSHooks().hasMPS();
  }
  static bool hasIPU() {
    return c10::impl::hasDeviceGuardImpl(at::DeviceType::IPU);
  }
  static bool hasXLA() {
    return c10::impl::hasDeviceGuardImpl(at::DeviceType::XLA);
  }
  static bool hasLazy() {
    return c10::impl::hasDeviceGuardImpl(at::DeviceType::Lazy);
  }
  static bool hasORT() {
    return c10::impl::hasDeviceGuardImpl(at::DeviceType::ORT);
  }
  // These are defined in the header so that getNonVariableType can inline the
  // call_once check; getNonVariableType is called fairly frequently.
  void lazyInitCUDA() {
    c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); });
  }
  void lazyInitHIP() {
    c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
  }
  static const at::cuda::NVRTC& getNVRTC() {
    return detail::getCUDAHooks().nvrtc();
  }

  static bool setFlushDenormal(bool on);

  // NB: This method reports *purely* whether or not a user requested that
  // CuDNN be enabled; it doesn't say anything about whether CuDNN is actually
  // usable. Use cudnn_is_acceptable to test that instead.
  bool userEnabledCuDNN() const;
  void setUserEnabledCuDNN(bool e);
  bool userEnabledMkldnn() const;
  void setUserEnabledMkldnn(bool e);
  bool benchmarkCuDNN() const;
  void setBenchmarkCuDNN(bool);
  int benchmarkLimitCuDNN() const;
  void setBenchmarkLimitCuDNN(int);
  bool deterministicCuDNN() const;
  void setDeterministicCuDNN(bool);

  // Note [Disabling Fused SDP Kernels]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Flash and Memory Efficient SDP kernels are enabled by default.
  // However, they can be disabled by calling
  // at::globalContext().setSDPUseFlash(false) (and, analogously,
  // setSDPUseMemEfficient(false)). This is useful for debugging purposes. For
  // example, if you want to compare the performance of the flash SDP kernels
  // with the unfused kernel, you can disable the flash SDP kernels.
  // Conversely, by disabling the math SDP kernel with
  // at::globalContext().setSDPUseMath(false), you can force your code to use
  // the flash kernels.
  void setSDPUseFlash(bool);
  bool userEnabledFlashSDP() const;

  void setSDPUseMemEfficient(bool);
  bool userEnabledMemEfficientSDP() const;

  void setSDPUseMath(bool);
  bool userEnabledMathSDP() const;
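
  // A minimal usage sketch (illustrative only), following the note above, to
  // force the fused flash kernel by disabling the alternatives:
  //
  //    at::globalContext().setSDPUseFlash(true);
  //    at::globalContext().setSDPUseMemEfficient(false);
  //    at::globalContext().setSDPUseMath(false);
  //    bool flash_enabled = at::globalContext().userEnabledFlashSDP(); // true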

  at::LinalgBackend linalgPreferredBackend() const;
  void setLinalgPreferredBackend(at::LinalgBackend);

  // Note [Enabling Deterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Operations in PyTorch that normally act nondeterministically, but have an
  // alternate deterministic implementation, should satisfy the following
  // requirements:
  //
  // * Include this comment: "See Note [Enabling Deterministic Operations]"
  //
  // * Check the value of `at::globalContext().deterministicAlgorithms()` to
  //   toggle between nondeterministic and deterministic implementations.
  //
  // * Have an entry in the list of PyTorch operations that toggle between
  //   nondeterministic and deterministic implementations, in the docstring of
  //   `use_deterministic_algorithms()` in torch/__init__.py
  //
  // `example_func()` below shows an example of toggling between
  // nondeterministic and deterministic implementations:
  //
  //    void example_func() {
  //      // See Note [Enabling Deterministic Operations]
  //      if (at::globalContext().deterministicAlgorithms()) {
  //        example_func_deterministic();
  //      } else {
  //        example_func_nondeterministic();
  //      }
  //    }

  bool deterministicAlgorithms() const;
  bool deterministicAlgorithmsWarnOnly() const;
  void setDeterministicAlgorithms(bool, bool);
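
  // A minimal usage sketch (illustrative only); the two arguments are assumed
  // to be (enabled, warn_only), consistent with the two getters above:
  //
  //    // Error on known-nondeterministic operations:
  //    at::globalContext().setDeterministicAlgorithms(true, /*warn_only=*/false);
  //    // Or merely warn instead of erroring:
  //    at::globalContext().setDeterministicAlgorithms(true, /*warn_only=*/true);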

  // Note [Writing Nondeterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Operations in PyTorch that act nondeterministically and do not have an
  // alternate deterministic implementation should satisfy the following
  // requirements:
  //
  // * Include this comment: "See Note [Writing Nondeterministic Operations]"
  //
  // * Include a comment explaining why the operation is nondeterministic.
  //
  // * Throw an error when `Context::deterministicAlgorithms()` is true. Most
  //   of the time, this should be accomplished by calling
  //   `at::globalContext().alertNotDeterministic()`. However, if the
  //   nondeterministic behavior is caused by the CuBLAS workspace
  //   configuration in CUDA >= 10.2,
  //   `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
  //   called instead (in this case, a comment explaining why the operation is
  //   nondeterministic is not necessary). See below for details on these
  //   methods.
  //
  // * Have an entry in the list of nondeterministic PyTorch operations in the
  //   docstring of `use_deterministic_algorithms()` in torch/__init__.py
  //
  // * Have a test function in `test/test_torch.py` whose name begins with
  //   `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
  //   configuration is the reason for nondeterminism, the operation should be
  //   included in the `test_cublas_config_nondeterministic_alert` test. Any new
  //   tests should ideally follow a pattern similar to the existing ones.
  //
  // `example_func()` below shows an example of the comments and error-throwing
  // code for a nondeterministic operation:
  //
  //    void example_func() {
  //      // See Note [Writing Nondeterministic Operations]
  //      // Nondeterministic because <reason>
  //      at::globalContext().alertNotDeterministic("example_func");
  //      ...
  //    }

  // Throws an error if `Context::deterministicAlgorithms()` is true
  static void alertNotDeterministic(c10::string_view const& caller);

  // Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
  // >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
  // ":4096:8". For more details:
  // https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
  void alertCuBLASConfigNotDeterministic() const;
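  //
  // For example, either accepted workspace configuration can be set in the
  // environment before the process starts to avoid this alert, e.g.:
  //
  //    export CUBLAS_WORKSPACE_CONFIG=:4096:8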

  void setFloat32MatmulPrecision(const std::string& s);
  bool allowTF32CuDNN() const;
  void setAllowTF32CuDNN(bool);
  bool allowTF32CuBLAS() const;
  void setAllowTF32CuBLAS(bool);
  Float32MatmulPrecision float32MatmulPrecision() const;
  void setFloat32MatmulPrecision(Float32MatmulPrecision p);
  bool allowFP16ReductionCuBLAS() const;
  void setAllowFP16ReductionCuBLAS(bool);
  bool allowBF16ReductionCuBLAS() const;
  void setAllowBF16ReductionCuBLAS(bool);
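  //
  // A minimal usage sketch (illustrative only; the string overload is assumed
  // to accept the lower-cased enum names "highest", "high" and "medium"):
  //
  //    at::globalContext().setFloat32MatmulPrecision("high");
  //    at::globalContext().setAllowTF32CuDNN(true);
  //    bool tf32_matmuls = at::globalContext().allowTF32CuBLAS();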
  at::QEngine qEngine() const;
  void setQEngine(at::QEngine e);
  static const std::vector<at::QEngine>& supportedQEngines();
  static bool isXNNPACKAvailable();
  void setCheckSparseTensorInvariants(bool e);
  bool checkSparseTensorInvariants() const;
  // This method is used to release the original weight after pre-packing.
  // It should be called once before loading/running the model.
  // NB: By default it is set to true for mobile builds.
  void setReleaseWeightsWhenPrepacking(bool e);
  bool releaseWeightsWhenPrepacking() const;
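  //
  // For example (illustrative only), to drop the original weights after
  // pre-packing, call this once before loading the model:
  //
  //    at::globalContext().setReleaseWeightsWhenPrepacking(true);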

  void setDisplayVmapFallbackWarnings(bool enabled);
  bool areVmapFallbackWarningsEnabled() const;

  void setDefaultMobileCPUAllocator();
  void unsetDefaultMobileCPUAllocator();

 private:
  void initCUDAIfNeeded(DeviceType p) {
    if (p == DeviceType::CUDA) {
      lazyInitCUDA();
    }
  }
  void initHIPIfNeeded(DeviceType p) {
    if (p == DeviceType::HIP) {
      lazyInitHIP();
    }
  }
  static bool checkCuBLASConfigDeterministic();
  c10::once_flag thc_init;
  c10::once_flag thh_init;
  bool enabled_cudnn = true;
  bool deterministic_cudnn = false;
  bool _deterministic_algorithms = false;
  bool _deterministic_algorithms_warn_only = false;
  bool enabled_flashSDP = true;
  bool enabled_mem_efficientSDP = true;
  bool enabled_mathSDP = true;
#ifdef USE_ROCM
  bool benchmark_cudnn = true;
#else
  bool benchmark_cudnn = false;
#endif
  Float32MatmulPrecision float32_matmul_precision =
      c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
      ? at::Float32MatmulPrecision::HIGH
      : at::Float32MatmulPrecision::HIGHEST;
  int benchmark_limit_cudnn = 10;
  bool allow_tf32_cudnn = true;
  bool allow_fp16_reduction_cublas = true;
  bool allow_bf16_reduction_cublas = true;
  bool enabled_mkldnn = true;
  at::LinalgBackend linalg_preferred_backend = at::LinalgBackend::Default;
#ifdef C10_MOBILE
  bool release_original_weights = true;
#else
  bool release_original_weights = false;
#endif
  bool display_vmap_fallback_warnings_ = false;
  c10::optional<at::QEngine> quantized_engine = c10::nullopt;
  bool enable_sparse_tensor_invariant_checks = false;

  Allocator* prev_allocator_ptr_{nullptr};
};

TORCH_API Context& globalContext();

static inline void init() {
  globalContext();
}

TORCH_API Allocator* getCPUAllocator();

static inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
    Backend p,
    ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      p, s);
}

static inline DeprecatedTypeProperties& CPU(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::CPU, s);
}

static inline DeprecatedTypeProperties& CUDA(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::CUDA, s);
}

static inline DeprecatedTypeProperties& HIP(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::HIP, s);
}

static inline DeprecatedTypeProperties& MPS(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::MPS, s);
}

static inline bool hasCUDA() {
  return globalContext().hasCUDA();
}

static inline bool hasHIP() {
  return globalContext().hasHIP();
}

static inline bool hasIPU() {
  return globalContext().hasIPU();
}

static inline bool hasXLA() {
  return globalContext().hasXLA();
}

static inline bool hasMPS() {
  return globalContext().hasMPS();
}

static inline bool hasORT() {
  return globalContext().hasORT();
}

// Despite its name, this function returns the number of *CUDA* GPUs.
static inline size_t getNumGPUs() {
  // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
  // FUNCTION. If you are interested in interrogating the number of
  // devices for a specific device type, add that function to the
  // relevant library (e.g., similar to at::cuda::device_count())
  if (hasCUDA() && hasHIP()) {
    throw std::runtime_error(
        "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
        "as CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
        "means HIP). Rebuild PyTorch with one or the other disabled.");
  } else if (hasCUDA()) {
    return detail::getCUDAHooks().getNumGPUs();
  } else if (hasHIP()) {
    return detail::getHIPHooks().getNumGPUs();
  } else {
    return 0;
  }
}

static inline bool hasOpenMP() {
  return globalContext().hasOpenMP();
}

static inline bool hasMKL() {
  return globalContext().hasMKL();
}

static inline bool hasLAPACK() {
  return globalContext().hasLAPACK();
}

static inline bool hasMAGMA() {
  return globalContext().hasMAGMA();
}

static inline bool hasMKLDNN() {
  return globalContext().hasMKLDNN();
}

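// Seeds the default CPU generator, the default generator of every visible
// CUDA device (when CUDA is available and GPUs are present), and the MPS
// generator (when MPS is available) with the given value.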
static inline void manual_seed(uint64_t seed) {
  auto gen = globalContext().defaultGenerator(DeviceType::CPU);
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen.mutex());
    gen.set_current_seed(seed);
  }
  // NB: Sometimes we build with CUDA, but we don't have any GPUs
  // available. In that case, we must not seed CUDA; it will fail!
  const auto num_gpus = detail::getCUDAHooks().getNumGPUs();
  if (hasCUDA() && num_gpus > 0) {
    for (const auto i : c10::irange(num_gpus)) {
      auto cuda_gen = globalContext().defaultGenerator(
          Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
      {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(cuda_gen.mutex());
        cuda_gen.set_current_seed(seed);
      }
    }
  }

  if (hasMPS()) {
    auto mps_gen = globalContext().defaultGenerator(DeviceType::MPS);
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(mps_gen.mutex());
    mps_gen.set_current_seed(seed);
  }
}

// When the global flag `allow_tf32` is set to true, cuBLAS handles are
// automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
// For some operators, such as addmv, TF32 offers no performance improvement
// but causes precision loss. To handle this case, this class implements
// a RAII guard that can be used to quickly disable TF32 within its scope.
//
// Usage:
// NoTF32Guard disable_tf32;
struct TORCH_API NoTF32Guard {
  NoTF32Guard();
  ~NoTF32Guard();
  static bool should_disable_tf32();

 private:
  bool changed = false;
};
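
// A minimal usage sketch (illustrative only; `run_addmv_like_op` is a
// hypothetical caller), following the note above:
//
//    void run_addmv_like_op() {
//      at::NoTF32Guard disable_tf32; // TF32 stays off until the guard is destroyed
//      // ... cuBLAS calls made in this scope will not use TF32 ...
//    }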

#ifdef USE_ROCM
struct TORCH_API ROCmBackwardPassGuard {
  ROCmBackwardPassGuard();
  ~ROCmBackwardPassGuard();
  static bool is_backward_pass();

 private:
  static thread_local bool is_backward_pass_;
};
#endif

} // namespace at