#pragma once

#include <ATen/CPUGeneratorImpl.h>
#include <ATen/LinalgBackend.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/DeprecatedTypeProperties.h>
#include <ATen/core/Generator.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/detail/HIPHooksInterface.h>
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/detail/ORTHooksInterface.h>
#include <c10/core/QEngine.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>
#include <c10/util/env.h>
#include <c10/util/irange.h>

#include <cstdint>
#include <memory>
#include <mutex>

namespace at {

class Tensor;

enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };

class TORCH_API Context {
 public:
  Context();

  const Generator& defaultGenerator(Device device) {
    DeviceType device_type = device.type();
    initCUDAIfNeeded(device_type);
    initHIPIfNeeded(device_type);
    if (device_type == at::kCPU) {
      return at::detail::getDefaultCPUGenerator();
    } else if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks().getDefaultCUDAGenerator(
          device.index());
    } else if (device_type == at::kMPS) {
      return at::detail::getMPSHooks().getDefaultMPSGenerator();
    } else {
      AT_ERROR(DeviceTypeName(device_type), " device type not enabled.");
    }
  }
  Device getDeviceFromPtr(void* data, DeviceType device_type) {
    initCUDAIfNeeded(device_type);
    initHIPIfNeeded(device_type);
    if (device_type == at::kCPU) {
      return DeviceType::CPU;
    } else if (device_type == at::kCUDA) {
      return at::detail::getCUDAHooks().getDeviceFromPtr(data);
    } else {
      AT_ERROR(DeviceTypeName(device_type), " device type not enabled.");
    }
  }
  static bool isPinnedPtr(void* data) {
    return detail::getCUDAHooks().isPinnedPtr(data);
  }
  static bool hasOpenMP();
  static bool hasMKL();
  static bool hasLAPACK();
  static bool hasMKLDNN();
  static bool hasMAGMA() {
    return detail::getCUDAHooks().hasMAGMA();
  }
  static bool hasCUDA() {
    return detail::getCUDAHooks().hasCUDA();
  }
  static bool hasCUDART() {
    return detail::getCUDAHooks().hasCUDART();
  }
  static long versionCUDART() {
    return detail::getCUDAHooks().versionCUDART();
  }
  static bool hasCuDNN() {
    return detail::getCUDAHooks().hasCuDNN();
  }
  static long versionCuDNN() {
    return detail::getCUDAHooks().versionCuDNN();
  }
  static bool hasCuSOLVER() {
    return detail::getCUDAHooks().hasCuSOLVER();
  }
  static bool hasHIP() {
    return detail::getHIPHooks().hasHIP();
  }
  static bool hasMPS() {
    return detail::getMPSHooks().hasMPS();
  }
  static bool hasIPU() {
    return c10::impl::hasDeviceGuardImpl(at::DeviceType::IPU);
  }
  static bool hasXLA() {
    return c10::impl::hasDeviceGuardImpl(at::DeviceType::XLA);
  }
  static bool hasLazy() {
    return c10::impl::hasDeviceGuardImpl(at::DeviceType::Lazy);
  }
  static bool hasORT() {
    return c10::impl::hasDeviceGuardImpl(at::DeviceType::ORT);
  }
  // Defined in the header so that getNonVariableType can inline the
  // call_once check; getNonVariableType is called fairly frequently.
  void lazyInitCUDA() {
    c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); });
  }
  void lazyInitHIP() {
    c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); });
  }
  static const at::cuda::NVRTC& getNVRTC() {
    return detail::getCUDAHooks().nvrtc();
  }

  static bool setFlushDenormal(bool on);

  // NB: This method reflects only whether the user has *requested* that
  // CuDNN be enabled; it says nothing about whether CuDNN is actually
  // usable. Use cudnn_is_acceptable to test that instead.
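  // A minimal sketch of the distinction (an illustrative call site, not
  // library code):
  //
  //   at::globalContext().setUserEnabledCuDNN(false);           // user opt-out
  //   bool requested = at::globalContext().userEnabledCuDNN();  // now false
  //   bool present = at::Context::hasCuDNN();  // may still report true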
  bool userEnabledCuDNN() const;
  void setUserEnabledCuDNN(bool e);
  bool userEnabledMkldnn() const;
  void setUserEnabledMkldnn(bool e);
  bool benchmarkCuDNN() const;
  void setBenchmarkCuDNN(bool);
  int benchmarkLimitCuDNN() const;
  void setBenchmarkLimitCuDNN(int);
  bool deterministicCuDNN() const;
  void setDeterministicCuDNN(bool);

  // Note [Disabling Fused SDP Kernels]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Flash and Memory Efficient SDP kernels are enabled by default.
  // However, they can be disabled by calling
  // at::globalContext().setSDPUseFlash(false) and
  // at::globalContext().setSDPUseMemEfficient(false).
  // This is useful for debugging. For example, to compare the performance of
  // the flash SDP kernels against the unfused kernel, disable the flash SDP
  // kernels. Conversely, disabling the math SDP kernel, via
  // at::globalContext().setSDPUseMath(false), forces your code onto the
  // fused kernels.
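  // A minimal sketch of such a debugging session (illustrative only; assumes
  // a build where the fused kernels are available):
  //
  //   at::globalContext().setSDPUseFlash(false);
  //   at::globalContext().setSDPUseMemEfficient(false);
  //   // Only the unfused math kernel remains enabled:
  //   run_attention_benchmark();  // hypothetical helper
  //   at::globalContext().setSDPUseFlash(true);
  //   at::globalContext().setSDPUseMemEfficient(true);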
  void setSDPUseFlash(bool);
  bool userEnabledFlashSDP() const;

  void setSDPUseMemEfficient(bool);
  bool userEnabledMemEfficientSDP() const;

  void setSDPUseMath(bool);
  bool userEnabledMathSDP() const;

  at::LinalgBackend linalgPreferredBackend() const;
  void setLinalgPreferredBackend(at::LinalgBackend);

  // Note [Enabling Deterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Operations in PyTorch that normally act nondeterministically, but have an
  // alternate deterministic implementation, should satisfy the following
  // requirements:
  //
  // * Include this comment: "See Note [Enabling Deterministic Operations]"
  //
  // * Check the value of `at::globalContext().deterministicAlgorithms()` to
  //   toggle between nondeterministic and deterministic implementations.
  //
  // * Have an entry in the list of PyTorch operations that toggle between
  //   nondeterministic and deterministic implementations, in the docstring of
  //   `use_deterministic_algorithms()` in torch/__init__.py
  //
  // `example_func()` below shows an example of toggling between
  // nondeterministic and deterministic implementations:
  //
  //   void example_func() {
  //     // See Note [Enabling Deterministic Operations]
  //     if (at::globalContext().deterministicAlgorithms()) {
  //       example_func_deterministic();
  //     } else {
  //       example_func_nondeterministic();
  //     }
  //   }

  bool deterministicAlgorithms() const;
  bool deterministicAlgorithmsWarnOnly() const;
  void setDeterministicAlgorithms(bool enabled, bool warn_only);

  // Note [Writing Nondeterministic Operations]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Operations in PyTorch that act nondeterministically and do not have an
  // alternate deterministic implementation should satisfy the following
  // requirements:
  //
  // * Include this comment: "See Note [Writing Nondeterministic Operations]"
  //
  // * Include a comment explaining why the operation is nondeterministic.
  //
  // * Throw an error when `Context::deterministicAlgorithms()` is true. Most
  //   of the time, this should be accomplished by calling
  //   `at::globalContext().alertNotDeterministic()`. However, if the
  //   nondeterministic behavior is caused by the CuBLAS workspace
  //   configuration in CUDA >= 10.2,
  //   `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
  //   called instead (in this case, a comment explaining why the operation is
  //   nondeterministic is not necessary). See below for details on these
  //   methods.
  //
  // * Have an entry in the list of nondeterministic PyTorch operations in the
  //   docstring of `use_deterministic_algorithms()` in torch/__init__.py
  //
  // * Have a test function in `test/test_torch.py` whose name begins with
  //   `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
  //   configuration is the reason for nondeterminism, the operation should be
  //   included in the `test_cublas_config_nondeterministic_alert` test. Any
  //   new tests should ideally follow a pattern similar to the existing ones.
  //
  // `example_func()` below shows an example of the comments and
  // error-throwing code for a nondeterministic operation:
  //
  //   void example_func() {
  //     // See Note [Writing Nondeterministic Operations]
  //     // Nondeterministic because <reason>
  //     at::globalContext().alertNotDeterministic("example_func");
  //     ...
  //   }

  // Throws an error if `Context::deterministicAlgorithms()` is true
  static void alertNotDeterministic(c10::string_view const& caller);

  // Throws an error if `Context::deterministicAlgorithms()` is true, CUDA
  // >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or
  // ":4096:8". For more details:
  // https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
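  //
  // For example, a deterministic run would typically be launched with the
  // environment variable set up front (the script name below is purely
  // illustrative):
  //
  //   CUBLAS_WORKSPACE_CONFIG=:4096:8 python train.py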
  void alertCuBLASConfigNotDeterministic() const;

  void setFloat32MatmulPrecision(const std::string& s);
  bool allowTF32CuDNN() const;
  void setAllowTF32CuDNN(bool);
  bool allowTF32CuBLAS() const;
  void setAllowTF32CuBLAS(bool);
  Float32MatmulPrecision float32MatmulPrecision() const;
  void setFloat32MatmulPrecision(Float32MatmulPrecision p);
  bool allowFP16ReductionCuBLAS() const;
  void setAllowFP16ReductionCuBLAS(bool);
  bool allowBF16ReductionCuBLAS() const;
  void setAllowBF16ReductionCuBLAS(bool);
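  // A minimal sketch of the matmul-precision controls above (the string
  // names mirror Float32MatmulPrecision; treat the exact backend behavior as
  // an assumption, not a guarantee):
  //
  //   at::globalContext().setFloat32MatmulPrecision("high");
  //   if (at::globalContext().float32MatmulPrecision() !=
  //       at::Float32MatmulPrecision::HIGHEST) {
  //     // reduced-precision float32 matmuls (e.g., TF32) may be used
  //   }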
  at::QEngine qEngine() const;
  void setQEngine(at::QEngine e);
  static const std::vector<at::QEngine>& supportedQEngines();
  static bool isXNNPACKAvailable();
  void setCheckSparseTensorInvariants(bool e);
  bool checkSparseTensorInvariants() const;
  // This flag controls whether the original weights are released after
  // pre-packing. It should be set once, before loading/running the model.
  // NB: It defaults to true for mobile builds.
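  // A minimal sketch (an illustrative deployment setup, not library code):
  //
  //   at::globalContext().setReleaseWeightsWhenPrepacking(true);
  //   // ... then load and run the model; prepacked ops may now free the
  //   // original weights to reduce peak memory.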
  void setReleaseWeightsWhenPrepacking(bool e);
  bool releaseWeightsWhenPrepacking() const;

  void setDisplayVmapFallbackWarnings(bool enabled);
  bool areVmapFallbackWarningsEnabled() const;

  void setDefaultMobileCPUAllocator();
  void unsetDefaultMobileCPUAllocator();

 private:
  void initCUDAIfNeeded(DeviceType p) {
    if (p == DeviceType::CUDA) {
      lazyInitCUDA();
    }
  }
  void initHIPIfNeeded(DeviceType p) {
    if (p == DeviceType::HIP) {
      lazyInitHIP();
    }
  }
  static bool checkCuBLASConfigDeterministic();
  c10::once_flag thc_init;
  c10::once_flag thh_init;
  bool enabled_cudnn = true;
  bool deterministic_cudnn = false;
  bool _deterministic_algorithms = false;
  bool _deterministic_algorithms_warn_only = false;
  bool enabled_flashSDP = true;
  bool enabled_mem_efficientSDP = true;
  bool enabled_mathSDP = true;
#ifdef USE_ROCM
  bool benchmark_cudnn = true;
#else
  bool benchmark_cudnn = false;
#endif
  Float32MatmulPrecision float32_matmul_precision =
      c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
      ? at::Float32MatmulPrecision::HIGH
      : at::Float32MatmulPrecision::HIGHEST;
  int benchmark_limit_cudnn = 10;
  bool allow_tf32_cudnn = true;
  bool allow_fp16_reduction_cublas = true;
  bool allow_bf16_reduction_cublas = true;
  bool enabled_mkldnn = true;
  at::LinalgBackend linalg_preferred_backend = at::LinalgBackend::Default;
#ifdef C10_MOBILE
  bool release_original_weights = true;
#else
  bool release_original_weights = false;
#endif
  bool display_vmap_fallback_warnings_ = false;
  c10::optional<at::QEngine> quantized_engine = c10::nullopt;
  bool enable_sparse_tensor_invariant_checks = false;

  Allocator* prev_allocator_ptr_{nullptr};
};

TORCH_API Context& globalContext();

static inline void init() {
  globalContext();
}

TORCH_API Allocator* getCPUAllocator();

static inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
    Backend p,
    ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      p, s);
}

static inline DeprecatedTypeProperties& CPU(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::CPU, s);
}

static inline DeprecatedTypeProperties& CUDA(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::CUDA, s);
}

static inline DeprecatedTypeProperties& HIP(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::HIP, s);
}

static inline DeprecatedTypeProperties& MPS(ScalarType s) {
  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
      Backend::MPS, s);
}

static inline bool hasCUDA() {
  return globalContext().hasCUDA();
}

static inline bool hasHIP() {
  return globalContext().hasHIP();
}

static inline bool hasIPU() {
  return globalContext().hasIPU();
}

static inline bool hasXLA() {
  return globalContext().hasXLA();
}

static inline bool hasMPS() {
  return globalContext().hasMPS();
}

static inline bool hasORT() {
  return globalContext().hasORT();
}

// Despite its name, this function returns the number of *CUDA* GPUs.
static inline size_t getNumGPUs() {
  // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
  // FUNCTION. If you are interested in interrogating the number of
  // devices for a specific device type, add that function to the
  // relevant library (e.g., similar to at::cuda::device_count())
  if (hasCUDA() && hasHIP()) {
    throw std::runtime_error(
        "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
        "as CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
        "means HIP). Rebuild PyTorch with one or the other disabled.");
  } else if (hasCUDA()) {
    return detail::getCUDAHooks().getNumGPUs();
  } else if (hasHIP()) {
    return detail::getHIPHooks().getNumGPUs();
  } else {
    return 0;
  }
}

static inline bool hasOpenMP() {
  return globalContext().hasOpenMP();
}

static inline bool hasMKL() {
  return globalContext().hasMKL();
}

static inline bool hasLAPACK() {
  return globalContext().hasLAPACK();
}

static inline bool hasMAGMA() {
  return globalContext().hasMAGMA();
}

static inline bool hasMKLDNN() {
  return globalContext().hasMKLDNN();
}

static inline void manual_seed(uint64_t seed) {
  auto gen = globalContext().defaultGenerator(DeviceType::CPU);
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen.mutex());
    gen.set_current_seed(seed);
  }
  // NB: Sometimes we build with CUDA, but we don't have any GPUs
  // available. In that case, we must not seed CUDA; it will fail!
  const auto num_gpus = detail::getCUDAHooks().getNumGPUs();
  if (hasCUDA() && num_gpus > 0) {
    for (const auto i : c10::irange(num_gpus)) {
      auto cuda_gen = globalContext().defaultGenerator(
          Device(at::kCUDA, static_cast<c10::DeviceIndex>(i)));
      {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(cuda_gen.mutex());
        cuda_gen.set_current_seed(seed);
      }
    }
  }

  if (hasMPS()) {
    auto mps_gen = globalContext().defaultGenerator(DeviceType::MPS);
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(mps_gen.mutex());
    mps_gen.set_current_seed(seed);
  }
}
// When the global flag `allow_tf32` is set to true, cuBLAS handles are
// automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
// For some operators, such as addmv, TF32 offers no performance improvement
// but causes precision loss. For such cases, this class implements a RAII
// guard that disables TF32 for the duration of its scope.
//
// Usage:
//   NoTF32Guard disable_tf32;
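//
// For example, a sketch of an illustrative call site (the surrounding
// function and helper are hypothetical):
//
//   Tensor addmv_like_op(const Tensor& a, const Tensor& b) {
//     at::NoTF32Guard disable_tf32;  // TF32 stays off until end of scope
//     return do_full_precision_matmul(a, b);  // hypothetical helper
//   }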
struct TORCH_API NoTF32Guard {
  NoTF32Guard();
  ~NoTF32Guard();
  static bool should_disable_tf32();

 private:
  bool changed = false;
};

#ifdef USE_ROCM
struct TORCH_API ROCmBackwardPassGuard {
  ROCmBackwardPassGuard();
  ~ROCmBackwardPassGuard();
  static bool is_backward_pass();

 private:
  static thread_local bool is_backward_pass_;
};
#endif

} // namespace at