#include <ATen/Config.h>

#include <ATen/Context.h>

#include <c10/core/CPUAllocator.h>

#include <algorithm>
#include <cctype>
#include <cstdlib>  // std::getenv
#include <cstring>  // strcmp
#include <string>

#include <ATen/cpu/FlushDenormal.h>

#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#endif // USE_FBGEMM
namespace at {

Context::Context() = default;

// TODO: This could be bad juju if someone calls globalContext() in the
// destructor of an object with static lifetime.
Context& globalContext() {
  static Context globalContext_;
  return globalContext_;
}
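// Illustrative usage (a sketch, not part of this file): process-wide toggles
// are reached through this accessor, e.g.
//   at::globalContext().setBenchmarkCuDNN(true);
//   bool det = at::globalContext().deterministicAlgorithms();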

// NB: This method reflects only whether the user has *requested* that CuDNN be
// enabled; it says nothing about whether CuDNN is actually usable.
bool Context::userEnabledCuDNN() const {
  return enabled_cudnn;
}

void Context::setUserEnabledCuDNN(bool e) {
  enabled_cudnn = e;
}

bool Context::userEnabledMkldnn() const {
  return enabled_mkldnn;
}

void Context::setUserEnabledMkldnn(bool e) {
  enabled_mkldnn = e;
}

bool Context::deterministicCuDNN() const {
  return deterministic_cudnn;
}

void Context::setDeterministicCuDNN(bool b) {
  deterministic_cudnn = b;
}

bool Context::deterministicAlgorithms() const {
  return _deterministic_algorithms;
}

bool Context::deterministicAlgorithmsWarnOnly() const {
  return _deterministic_algorithms_warn_only;
}

void Context::setDeterministicAlgorithms(bool b, bool warn_only) {
  _deterministic_algorithms = b;
  _deterministic_algorithms_warn_only = warn_only;
}

void Context::alertNotDeterministic(c10::string_view const& caller) {
  if (globalContext().deterministicAlgorithms()) {
    if (globalContext().deterministicAlgorithmsWarnOnly()) {
      TORCH_WARN(
        caller, " does not have a deterministic implementation, but you set "
        "'torch.use_deterministic_algorithms(True, warn_only=True)'. "
        "You can file an issue at https://github.com/pytorch/pytorch/issues "
        "to help us prioritize adding deterministic support for this operation.");
    } else {
      TORCH_CHECK(false,
        caller, " does not have a deterministic implementation, but you set "
        "'torch.use_deterministic_algorithms(True)'. You can turn off "
        "determinism just for this operation, or you can use the "
        "'warn_only=True' option, if that's acceptable for your application. "
        "You can also file an issue at https://github.com/pytorch/pytorch/issues "
        "to help us prioritize adding deterministic support for this operation.");
    }
  }
}
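// Illustrative call site (a sketch; the operator name below is a placeholder):
// a kernel without a deterministic implementation would typically guard itself
// at the top of its implementation with
//   at::globalContext().alertNotDeterministic("my_nondeterministic_op");
// so that torch.use_deterministic_algorithms(True) warns or errors as configured.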

bool Context::allowTF32CuDNN() const {
  return allow_tf32_cudnn;
}

void Context::setAllowTF32CuDNN(bool b) {
  allow_tf32_cudnn = b;
}

bool Context::userEnabledFlashSDP() const {
  return enabled_flashSDP;
}

void Context::setSDPUseFlash(bool e) {
  enabled_flashSDP = e;
}

bool Context::userEnabledMemEfficientSDP() const {
  return enabled_mem_efficientSDP;
}

void Context::setSDPUseMemEfficient(bool e) {
  enabled_mem_efficientSDP = e;
}

bool Context::userEnabledMathSDP() const {
  return enabled_mathSDP;
}

void Context::setSDPUseMath(bool e) {
  enabled_mathSDP = e;
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static const char cublas_config_var_name[] = "CUBLAS_WORKSPACE_CONFIG";
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static const char* const cublas_deterministic_configs[] = { ":4096:8", ":16:8" };

bool Context::checkCuBLASConfigDeterministic() {
  bool cublas_config_deterministic = true;
  // If using CUDA 10.2 or greater, need to make sure the CuBLAS workspace config
  // is set to a deterministic setting
  if (hasCUDART() && (versionCUDART() >= 10020)) {
    char* workspace_config = std::getenv(cublas_config_var_name);
    cublas_config_deterministic = (workspace_config != nullptr) && (
      (strcmp(workspace_config, cublas_deterministic_configs[0]) == 0)
      || (strcmp(workspace_config, cublas_deterministic_configs[1]) == 0)
    );
  }
  return cublas_config_deterministic;
}

void Context::alertCuBLASConfigNotDeterministic() const {
  static bool cublas_config_deterministic = checkCuBLASConfigDeterministic();
  if (C10_LIKELY(!deterministicAlgorithms() || cublas_config_deterministic)) {
    return;
  }

  auto msg = c10::str(
      "Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or ",
      "`at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because ",
      "it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this ",
      "case, you must set an environment variable before running your PyTorch application: ",
      cublas_config_var_name, "=", cublas_deterministic_configs[0], " or ",
      cublas_config_var_name, "=", cublas_deterministic_configs[1], ". For more information, go to ",
      "https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility");

  if (deterministicAlgorithmsWarnOnly()) {
    TORCH_WARN(msg);
  } else {
    TORCH_CHECK(false, msg);
  }
}
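// Illustrative shell setup (a sketch; "train.py" is a placeholder): with
// CUDA >= 10.2, deterministic cuBLAS requires exporting the workspace config
// before the process starts, e.g.
//   CUBLAS_WORKSPACE_CONFIG=:4096:8 python train.py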

bool Context::benchmarkCuDNN() const {
  return benchmark_cudnn;
}

void Context::setBenchmarkCuDNN(bool b) {
  benchmark_cudnn = b;
}

int Context::benchmarkLimitCuDNN() const {
  return benchmark_limit_cudnn;
}

void Context::setBenchmarkLimitCuDNN(int b) {
  benchmark_limit_cudnn = b;
}

bool Context::allowTF32CuBLAS() const {
  return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
}

void Context::setAllowTF32CuBLAS(bool b) {
  float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
}

Float32MatmulPrecision Context::float32MatmulPrecision() const {
  return float32_matmul_precision;
}

void Context::setFloat32MatmulPrecision(Float32MatmulPrecision p) {
  float32_matmul_precision = p;
}

void Context::setFloat32MatmulPrecision(const std::string &s) {
  auto match = [this](const std::string & s_) {
    // TODO: consider whether the CuDNN field also needs to be set for potential future CuDNN ops like multi-headed attention
    if (s_ == "highest") {
      float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
      return true;
    } else if (s_ == "high") {
      float32_matmul_precision = at::Float32MatmulPrecision::HIGH;
      return true;
    } else if (s_ == "medium") {
      float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM;
      return true;
    }
    return false;
  };
  if (match(s)) { return; }
  // Retry with a lowercased copy of the input.
  std::string sl;
  sl.resize(s.size());
  std::transform(s.begin(), s.end(), sl.begin(),
      [](unsigned char c) -> unsigned char { return std::tolower(c); });
  if (match(sl)) { return; }
  TORCH_WARN(s, " is not one of 'highest', 'high', or 'medium'; the current "
      "setFloat32MatmulPrecision call has no effect.");
}
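// Illustrative mapping (a sketch of the Python-facing entry point, which is
// not defined in this file):
//   torch.set_float32_matmul_precision("highest") -> Float32MatmulPrecision::HIGHEST (TF32 off for cuBLAS)
//   torch.set_float32_matmul_precision("high")    -> Float32MatmulPrecision::HIGH (TF32 allowed)
//   torch.set_float32_matmul_precision("medium")  -> Float32MatmulPrecision::MEDIUM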

at::LinalgBackend Context::linalgPreferredBackend() const {
  return linalg_preferred_backend;
}

void Context::setLinalgPreferredBackend(at::LinalgBackend b) {
  // Validate before mutating the stored preference, so a failed check does not
  // leave the context pointing at an unavailable backend.
  TORCH_CHECK((b != at::LinalgBackend::Cusolver) || hasCuSOLVER(),
      "Cannot set preferred backend to cuSOLVER if PyTorch has not been compiled with cuSOLVER.");
  TORCH_CHECK((b != at::LinalgBackend::Magma) || hasMAGMA(),
      "Cannot set preferred backend to MAGMA if PyTorch has not been compiled with MAGMA.");
  linalg_preferred_backend = b;
  if (b != at::LinalgBackend::Default) {
    TORCH_WARN_ONCE(
      "torch.backends.cuda.preferred_linalg_library is an experimental feature. "
      "If you see any error or unexpected behavior when this flag is set "
      "please file an issue on GitHub."
    );
  }
}
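// Illustrative usage (a sketch of the Python-facing API referenced in the
// warning above):
//   torch.backends.cuda.preferred_linalg_library("cusolver")
// routes to setLinalgPreferredBackend(at::LinalgBackend::Cusolver).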

bool Context::allowFP16ReductionCuBLAS() const {
  return allow_fp16_reduction_cublas;
}

void Context::setAllowFP16ReductionCuBLAS(bool b) {
  allow_fp16_reduction_cublas = b;
}

bool Context::allowBF16ReductionCuBLAS() const {
  return allow_bf16_reduction_cublas;
}

void Context::setAllowBF16ReductionCuBLAS(bool b) {
  allow_bf16_reduction_cublas = b;
}

bool Context::hasMKL() {
#if AT_MKL_ENABLED()
  return true;
#else
  return false;
#endif
}

bool Context::hasMKLDNN() {
#if AT_MKLDNN_ENABLED()
  return true;
#else
  return false;
#endif
}

bool Context::hasOpenMP() {
#ifdef _OPENMP
  return true;
#else
  return false;
#endif
}

bool Context::hasLAPACK() {
#if AT_BUILD_WITH_LAPACK()
  return true;
#else
  return false;
#endif
}

at::QEngine Context::qEngine() const {
  static auto _quantized_engine = []() {
    at::QEngine qengine = at::kNoQEngine;
#if defined(C10_MOBILE) && defined(USE_PYTORCH_QNNPACK)
    qengine = at::kQNNPACK;
#endif

#if AT_MKLDNN_ENABLED()
    qengine = at::kONEDNN;
#endif

#ifdef USE_FBGEMM
    if (fbgemm::fbgemmSupportedCPU()) {
      /* X86 is enabled if and only if fbgemm is available.
       * It combines the strengths of fbgemm and onednn by dispatching between them.
       * If onednn is not available, it always dispatches to fbgemm.
       * Make it the default qengine for X86 CPU platforms.
       */
      qengine = at::kX86;
    }
#endif
    return qengine;
  }();
  return quantized_engine.value_or(_quantized_engine);
}

void Context::setQEngine(at::QEngine e) {
  const auto& qengines = supportedQEngines();
  if (std::find(qengines.begin(), qengines.end(), e) != qengines.end()) {
    quantized_engine = e;
    return;
  }
  TORCH_CHECK(false, "quantized engine ", toString(e), " is not supported");
}
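// Illustrative usage (a sketch of the Python-facing API, not defined here):
//   torch.backends.quantized.engine = "fbgemm"
// ends up calling setQEngine(at::kFBGEMM); engines outside supportedQEngines()
// are rejected by the check above.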

const std::vector<at::QEngine>& Context::supportedQEngines() {
  static auto supported_qengines = []() {
    std::vector<at::QEngine> engines = {};
    // Engines are listed in priority order: later one wins
    // By default we prefer FBGEMM if we're running on server side
    // QNNPACK on server side has some issue, so we disable it by default.
#ifdef C10_MOBILE
    engines.push_back(at::kNoQEngine);
#ifdef USE_PYTORCH_QNNPACK
    engines.push_back(at::kQNNPACK);
#endif
#else // C10_MOBILE
#ifdef USE_PYTORCH_QNNPACK
    engines.push_back(at::kQNNPACK);
#endif
    engines.push_back(at::kNoQEngine);
#endif // C10_MOBILE

#if AT_MKLDNN_ENABLED()
    engines.push_back(at::kONEDNN);
#endif

#ifdef USE_FBGEMM
    if (fbgemm::fbgemmSupportedCPU()) {
      engines.push_back(at::kX86);
      // The X86 qengine is available if and only if FBGEMM is available
      engines.push_back(at::kFBGEMM);
    }
#endif

    return engines;
  }();
  return supported_qengines;
}

bool Context::isXNNPACKAvailable() {
#ifdef USE_XNNPACK
  return true;
#else
  return false;
#endif
}

void Context::setCheckSparseTensorInvariants(bool e) {
  enable_sparse_tensor_invariant_checks = e;
}

bool Context::checkSparseTensorInvariants() const {
  return enable_sparse_tensor_invariant_checks;
}

bool Context::releaseWeightsWhenPrepacking() const {
  return release_original_weights;
}

void Context::setReleaseWeightsWhenPrepacking(bool e) {
  release_original_weights = e;
}

bool Context::setFlushDenormal(bool on) {
  return at::cpu::set_flush_denormal(on);
}

Allocator* getCPUAllocator() {
  return c10::GetCPUAllocator();
}

// override_allow_tf32_flag = true
//    means the allow_tf32 flags are overridden and TF32 is force-disabled
// override_allow_tf32_flag = false
//    means the original allow_tf32 flags are followed
thread_local bool override_allow_tf32_flag = false;

NoTF32Guard::NoTF32Guard() {
  if (!override_allow_tf32_flag) {
    changed = true;
    override_allow_tf32_flag = true;
  }
}

NoTF32Guard::~NoTF32Guard() {
  if (changed) {
    override_allow_tf32_flag = false;
  }
}

bool NoTF32Guard::should_disable_tf32() {
  return override_allow_tf32_flag;
}
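// Illustrative usage (a sketch): code that must run float32 matmuls at full
// precision can scope TF32 off for the current thread:
//   {
//     at::NoTF32Guard guard;  // TF32 reported as disabled while in scope
//     // ... precision-sensitive computation ...
//   }                         // prior state restored by the destructor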

#ifdef USE_ROCM
// Ops can query this flag to know they are in the backward pass.
// This information can be used, for example, to select implementations
// with different numerical or performance characteristics.
// See https://pytorch.org/docs/stable/notes/numerical_accuracy.html for details.
thread_local bool ROCmBackwardPassGuard::is_backward_pass_;

ROCmBackwardPassGuard::ROCmBackwardPassGuard() {
  is_backward_pass_ = true;
}

ROCmBackwardPassGuard::~ROCmBackwardPassGuard() {
  is_backward_pass_ = false;
}

bool ROCmBackwardPassGuard::is_backward_pass() {
  return is_backward_pass_;
}
#endif

bool Context::areVmapFallbackWarningsEnabled() const {
  return display_vmap_fallback_warnings_;
}

void Context::setDisplayVmapFallbackWarnings(bool enabled) {
  display_vmap_fallback_warnings_ = enabled;
}

void Context::setDefaultMobileCPUAllocator() {
  TORCH_CHECK(prev_allocator_ptr_ == nullptr,
      "Already within the scope of another non-default CPU allocator. "
      "Cannot set another allocator.");
  // Set the priority high to make sure no other allocator is used instead of this one.
  prev_allocator_ptr_ = c10::GetCPUAllocator();
  c10::SetCPUAllocator(c10::GetDefaultMobileCPUAllocator(), /*priority*/ 100);
}

void Context::unsetDefaultMobileCPUAllocator() {
  TORCH_CHECK(prev_allocator_ptr_ != nullptr,
      "setDefaultMobileCPUAllocator must have been called "
      "before unsetDefaultMobileCPUAllocator.");
  // Restore the previous allocator, again at high priority so nothing else takes precedence.
  c10::SetCPUAllocator(prev_allocator_ptr_, /*priority*/ 100);
  prev_allocator_ptr_ = nullptr;
}
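// Illustrative usage (a sketch): the two mobile-allocator calls above are
// intended to be used as a matched pair, e.g.
//   at::globalContext().setDefaultMobileCPUAllocator();
//   // ... load / prepack a mobile model ...
//   at::globalContext().unsetDefaultMobileCPUAllocator();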
} // namespace at