1 | #include <ATen/Config.h> |
2 | |
3 | #include <ATen/Context.h> |
4 | |
5 | #include <c10/core/CPUAllocator.h> |
6 | |
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <string>
10 | |
11 | #include <ATen/cpu/FlushDenormal.h> |
12 | |
13 | #ifdef USE_FBGEMM |
14 | #include <fbgemm/Fbgemm.h> |
15 | #endif // USE_FBGEMM |
16 | |
17 | namespace at { |
18 | |
Context::Context() = default;

// TODO: This could be bad juju if someone calls globalContext() in the
// destructor of an object with static lifetime.
// Accessor for the process-wide ATen context (Meyers singleton): the
// function-local static guarantees thread-safe construction on first use.
Context& globalContext() {
  static Context globalContext_;
  return globalContext_;
}
27 | |
// NB: This method is *purely* whether or not a user requested
// that CuDNN was enabled, it doesn't actually say anything about
// whether or not CuDNN is actually usable.
bool Context::userEnabledCuDNN() const {
  return enabled_cudnn;
}

// Records the user's preference for CuDNN; performs no availability check.
void Context::setUserEnabledCuDNN(bool e) {
  enabled_cudnn = e;
}

// Like userEnabledCuDNN: reflects only the user's request for MKL-DNN
// (oneDNN), not whether the backend is actually usable.
bool Context::userEnabledMkldnn() const {
  return enabled_mkldnn;
}

// Records the user's preference for MKL-DNN (oneDNN).
void Context::setUserEnabledMkldnn(bool e) {
  enabled_mkldnn = e;
}
46 | |
// Whether CuDNN is restricted to deterministic algorithms
// (torch.backends.cudnn.deterministic).
bool Context::deterministicCuDNN() const {
  return deterministic_cudnn;
}

void Context::setDeterministicCuDNN(bool b) {
  deterministic_cudnn = b;
}

// Global flag set by torch.use_deterministic_algorithms(True).
bool Context::deterministicAlgorithms() const {
  return _deterministic_algorithms;
}

// When true, nondeterministic operations emit a warning instead of raising.
bool Context::deterministicAlgorithmsWarnOnly() const {
  return _deterministic_algorithms_warn_only;
}
62 | |
63 | void Context::setDeterministicAlgorithms(bool b, bool warn_only=false) { |
64 | _deterministic_algorithms = b; |
65 | _deterministic_algorithms_warn_only = warn_only; |
66 | } |
67 | |
68 | void Context::alertNotDeterministic(c10::string_view const& caller) { |
69 | if (globalContext().deterministicAlgorithms()) { |
70 | if (globalContext().deterministicAlgorithmsWarnOnly()) { |
71 | TORCH_WARN( |
72 | caller, " does not have a deterministic implementation, but you set " |
73 | "'torch.use_deterministic_algorithms(True, warn_only=True)'. " |
74 | "You can file an issue at https://github.com/pytorch/pytorch/issues " |
75 | "to help us prioritize adding deterministic support for this operation." ); |
76 | } else { |
77 | TORCH_CHECK(false, |
78 | caller, " does not have a deterministic implementation, but you set " |
79 | "'torch.use_deterministic_algorithms(True)'. You can turn off " |
80 | "determinism just for this operation, or you can use the " |
81 | "'warn_only=True' option, if that's acceptable for your application. " |
82 | "You can also file an issue at https://github.com/pytorch/pytorch/issues " |
83 | "to help us prioritize adding deterministic support for this operation." ); |
84 | } |
85 | } |
86 | } |
87 | |
// Whether TF32 tensor cores may be used by CuDNN convolutions
// (torch.backends.cudnn.allow_tf32).
bool Context::allowTF32CuDNN() const {
  return allow_tf32_cudnn;
}

void Context::setAllowTF32CuDNN(bool b) {
  allow_tf32_cudnn = b;
}

// --- Scaled dot-product attention (SDP) backend toggles. Each pair records
// --- whether the user allows the corresponding kernel to be selected.

// FlashAttention kernel enabled?
bool Context::userEnabledFlashSDP() const {
  return enabled_flashSDP;
}

void Context::setSDPUseFlash(bool e) {
  enabled_flashSDP = e;
}

// Memory-efficient attention kernel enabled?
bool Context::userEnabledMemEfficientSDP() const {
  return enabled_mem_efficientSDP;
}

void Context::setSDPUseMemEfficient(bool e) {
  enabled_mem_efficientSDP = e;
}

// Plain math (composite) attention fallback enabled?
bool Context::userEnabledMathSDP() const {
  return enabled_mathSDP;
}

void Context::setSDPUseMath(bool e) {
  enabled_mathSDP = e;
}
119 | |
// Environment variable cuBLAS reads for its workspace configuration, and the
// two values that make cuBLAS reproducible on CUDA >= 10.2 (see NVIDIA's
// cublasApi_reproducibility notes).
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static const char cublas_config_var_name[] = "CUBLAS_WORKSPACE_CONFIG";
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static const char* const cublas_deterministic_configs[] = { ":4096:8", ":16:8" };
124 | |
125 | bool Context::checkCuBLASConfigDeterministic() { |
126 | bool cublas_config_deterministic = true; |
127 | // If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config |
128 | // is set to deterministic setting |
129 | if (hasCUDART() && (versionCUDART() >= 10020)) { |
130 | char* workspace_config = std::getenv(cublas_config_var_name); |
131 | cublas_config_deterministic = (workspace_config != nullptr) && ( |
132 | (strcmp(workspace_config, cublas_deterministic_configs[0]) == 0) |
133 | || (strcmp(workspace_config, cublas_deterministic_configs[1]) == 0) |
134 | ); |
135 | } |
136 | return cublas_config_deterministic; |
137 | } |
138 | |
139 | void Context::alertCuBLASConfigNotDeterministic() const { |
140 | static bool cublas_config_deterministic = checkCuBLASConfigDeterministic(); |
141 | if (C10_LIKELY(!deterministicAlgorithms() || cublas_config_deterministic)) { |
142 | return; |
143 | } |
144 | |
145 | auto msg = c10::str( |
146 | "Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or " , |
147 | "`at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because " , |
148 | "it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this " , |
149 | "case, you must set an environment variable before running your PyTorch application: " , |
150 | cublas_config_var_name, "=" , cublas_deterministic_configs[0], " or " , |
151 | cublas_config_var_name, "=" , cublas_deterministic_configs[1], ". For more information, go to " , |
152 | "https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility" |
153 | ); |
154 | |
155 | if (deterministicAlgorithmsWarnOnly()) { |
156 | TORCH_WARN(msg); |
157 | } else { |
158 | TORCH_CHECK(false, msg); |
159 | } |
160 | } |
161 | |
// Whether CuDNN autotuning (torch.backends.cudnn.benchmark) is enabled.
bool Context::benchmarkCuDNN() const {
  return benchmark_cudnn;
}

void Context::setBenchmarkCuDNN(bool b) {
  benchmark_cudnn = b;
}

// Maximum number of CuDNN algorithms tried during benchmarking (0 = no limit).
int Context::benchmarkLimitCuDNN() const {
  return benchmark_limit_cudnn;
}

void Context::setBenchmarkLimitCuDNN(int b) {
  benchmark_limit_cudnn = b;
}

// TF32 in cuBLAS is derived from the float32 matmul precision: any setting
// other than HIGHEST permits TF32.
bool Context::allowTF32CuBLAS() const {
  return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
}

// Maps the legacy boolean API onto the precision enum: true -> HIGH
// (TF32 allowed), false -> HIGHEST (full fp32).
void Context::setAllowTF32CuBLAS(bool b) {
  float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
}
185 | |
// Current float32 matmul precision (torch.get_float32_matmul_precision).
Float32MatmulPrecision Context::float32MatmulPrecision() const {
  return float32_matmul_precision;
}

// Enum overload; the string overload below performs parsing/validation.
void Context::setFloat32MatmulPrecision(Float32MatmulPrecision p) {
  float32_matmul_precision = p;
}
193 | |
194 | void Context::setFloat32MatmulPrecision(const std::string &s) { |
195 | auto match = [this](const std::string & s_) { |
196 | // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention |
197 | if (s_ == "highest" ) { |
198 | float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST; |
199 | return true; |
200 | } else if (s_ == "high" ) { |
201 | float32_matmul_precision = at::Float32MatmulPrecision::HIGH; |
202 | return true; |
203 | } else if (s_ == "medium" ) { |
204 | float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM; |
205 | return true; |
206 | } |
207 | return false; |
208 | }; |
209 | if (match(s)) { return; } |
210 | std::string sl; |
211 | std::transform(s.begin(), s.end(), sl.begin(), |
212 | [](unsigned char c) -> unsigned char { return std::tolower(c); }); |
213 | if (match(sl)) { return; } |
214 | TORCH_WARN(s, " is not one of 'highest', 'high', or 'medium'; the current" |
215 | "setFloat32MatmulPrecision call has no effect." ); |
216 | } |
217 | |
218 | at::LinalgBackend Context::linalgPreferredBackend() const { |
219 | return linalg_preferred_backend; |
220 | } |
221 | |
222 | void Context::setLinalgPreferredBackend(at::LinalgBackend b) { |
223 | linalg_preferred_backend = b; |
224 | TORCH_CHECK((b != at::LinalgBackend::Cusolver) || hasCuSOLVER(), |
225 | "Cannot set preferred backend to cuSOLVER if PyTorch has not been compiled with cuSOLVER." ); |
226 | TORCH_CHECK((b != at::LinalgBackend::Magma) || hasMAGMA(), |
227 | "Cannot set preferred backend to MAGMA if PyTorch has not been compiled with MAGMA." ); |
228 | if (b != at::LinalgBackend::Default) { |
229 | TORCH_WARN_ONCE( |
230 | "torch.backends.cuda.preferred_linalg_library is an experimental feature. " |
231 | "If you see any error or unexpected behavior when this flag is set " |
232 | "please file an issue on GitHub." |
233 | ); |
234 | } |
235 | } |
236 | |
// Whether cuBLAS may use reduced-precision (fp16) accumulation in fp16 GEMMs.
bool Context::allowFP16ReductionCuBLAS() const {
  return allow_fp16_reduction_cublas;
}

void Context::setAllowFP16ReductionCuBLAS(bool b) {
  allow_fp16_reduction_cublas = b;
}

// Whether cuBLAS may use reduced-precision accumulation in bf16 GEMMs.
bool Context::allowBF16ReductionCuBLAS() const {
  return allow_bf16_reduction_cublas;
}

void Context::setAllowBF16ReductionCuBLAS(bool b) {
  allow_bf16_reduction_cublas = b;
}
252 | |
253 | |
// --- Compile-time capability queries: each answers whether this build of
// --- PyTorch was compiled with the given library/feature.

bool Context::hasMKL() {
#if AT_MKL_ENABLED()
  return true;
#else
  return false;
#endif
}

bool Context::hasMKLDNN() {
#if AT_MKLDNN_ENABLED()
  return true;
#else
  return false;
#endif
}

bool Context::hasOpenMP() {
#ifdef _OPENMP
  return true;
#else
  return false;
#endif
}

bool Context::hasLAPACK() {
#if AT_BUILD_WITH_LAPACK()
  return true;
#else
  return false;
#endif
}
285 | |
// Returns the active quantization engine: a user-set engine if one was chosen
// via setQEngine, otherwise a build-dependent default computed once.
// NOTE: the #if blocks below are ordered by priority -- each later block
// overwrites `qengine`, so the final default is X86 > ONEDNN > QNNPACK.
at::QEngine Context::qEngine() const {
  static auto _quantized_engine = []() {
    at::QEngine qengine = at::kNoQEngine;
#if defined(C10_MOBILE) && defined(USE_PYTORCH_QNNPACK)
    qengine = at::kQNNPACK;
#endif

#if AT_MKLDNN_ENABLED()
    qengine = at::kONEDNN;
#endif

#ifdef USE_FBGEMM
    if (fbgemm::fbgemmSupportedCPU()) {
      /* X86 is enabled if and only if fbgemm is available.
       * It combines goodness of fbgemm and onednn by dispatching.
       * If onednn not available, always dispatch to fbgemm.
       * Make it default qengine for X86 CPU platforms.
      */
      qengine = at::kX86;
    }
#endif
    return qengine;
  }();
  return quantized_engine.value_or(_quantized_engine);
}
311 | |
312 | void Context::setQEngine(at::QEngine e) { |
313 | const auto& qengines = supportedQEngines(); |
314 | if (std::find(qengines.begin(), qengines.end(), e) != qengines.end()) { |
315 | quantized_engine = e; |
316 | return; |
317 | } |
318 | TORCH_CHECK(false, "quantized engine " , toString(e), " is not supported" ); |
319 | } |
320 | |
// Returns the quantization engines available in this build, computed once.
// The returned vector is ordered by priority: later entries win.
const std::vector<at::QEngine>& Context::supportedQEngines() {
  static auto supported_qengines = []() {
    std::vector<at::QEngine> engines = {};
    // Engines are listed in priority order: later one wins
    // By default we prefer FBGEMM if we're running on server side
    // QNNPACK on server side has some issue, so we disable it by default.
#ifdef C10_MOBILE
    engines.push_back(at::kNoQEngine);
#ifdef USE_PYTORCH_QNNPACK
    engines.push_back(at::kQNNPACK);
#endif
#else  // C10_MOBILE
#ifdef USE_PYTORCH_QNNPACK
    engines.push_back(at::kQNNPACK);
#endif
    engines.push_back(at::kNoQEngine);
#endif  // C10_MOBILE

#if AT_MKLDNN_ENABLED()
    engines.push_back(at::kONEDNN);
#endif

#ifdef USE_FBGEMM
    if (fbgemm::fbgemmSupportedCPU()) {
      engines.push_back(at::kX86);
      // The X86 qengine is available if and only if FBGEMM is available
      engines.push_back(at::kFBGEMM);
    }
#endif

    return engines;
  }();
  return supported_qengines;
}
355 | |
// Compile-time query: was this build compiled with XNNPACK?
bool Context::isXNNPACKAvailable() {
#ifdef USE_XNNPACK
  return true;
#else
  return false;
#endif
}

// Toggles extra invariant validation when constructing sparse tensors.
void Context::setCheckSparseTensorInvariants(bool e) {
  enable_sparse_tensor_invariant_checks = e;
}

bool Context::checkSparseTensorInvariants() const {
  return enable_sparse_tensor_invariant_checks;
}

// Whether original weights may be freed after prepacking (saves memory on
// mobile at the cost of not being able to re-prepack).
bool Context::releaseWeightsWhenPrepacking() const {
  return release_original_weights;
}

void Context::setReleaseWeightsWhenPrepacking(bool e) {
  release_original_weights = e;
}

// Enables/disables CPU flush-to-zero for denormal floats; returns whether the
// platform supports the setting.
bool Context::setFlushDenormal(bool on) {
  return at::cpu::set_flush_denormal(on);
}
383 | |
// Convenience forwarder to the currently registered c10 CPU allocator.
Allocator* getCPUAllocator() {
  return c10::GetCPUAllocator();
}
387 | |
388 | // override_allow_tf32_flag = true |
389 | // means the allow_tf32 flags are overrided and tf32 is force disabled |
390 | // override_allow_tf32_flag = false |
391 | // means the original allow_tf32 flags are followed |
392 | thread_local bool override_allow_tf32_flag = false; |
393 | |
394 | NoTF32Guard::NoTF32Guard() { |
395 | if (!override_allow_tf32_flag) { |
396 | changed = true; |
397 | override_allow_tf32_flag = true; |
398 | } |
399 | } |
400 | |
401 | NoTF32Guard::~NoTF32Guard() { |
402 | if (changed) { |
403 | override_allow_tf32_flag = false; |
404 | } |
405 | } |
406 | |
407 | bool NoTF32Guard::should_disable_tf32() { |
408 | return override_allow_tf32_flag; |
409 | } |
410 | |
#ifdef USE_ROCM
// Ops can query this flag to know they are in the backward pass.
// This information can be used, for example, to select implementations
// with different numerical or performance characteristics.
// See https://pytorch.org/docs/stable/notes/numerical_accuracy.html for details.
// Thread-local; zero-initialized, i.e. starts out false.
thread_local bool ROCmBackwardPassGuard::is_backward_pass_;

// RAII guard marking the backward pass on this thread.
// NOTE(review): unlike NoTF32Guard this does not remember the previous value,
// so nesting a guard inside another ends the "backward" state early -- appears
// to assume guards are never nested; confirm with callers.
ROCmBackwardPassGuard::ROCmBackwardPassGuard() {
  is_backward_pass_ = true;
}

ROCmBackwardPassGuard::~ROCmBackwardPassGuard() {
  is_backward_pass_ = false;
}

bool ROCmBackwardPassGuard::is_backward_pass() {
  return is_backward_pass_;
}
#endif
430 | |
// Whether vmap emits a performance warning when an op falls back to the slow
// (non-batched) path.
bool Context::areVmapFallbackWarningsEnabled() const {
  return display_vmap_fallback_warnings_;
}

void Context::setDisplayVmapFallbackWarnings(bool enabled) {
  display_vmap_fallback_warnings_ = enabled;
}
438 | |
439 | void Context::setDefaultMobileCPUAllocator() { |
440 | TORCH_CHECK(prev_allocator_ptr_ == nullptr, |
441 | "Already within the scope of another non-default cpu allocator." |
442 | "Cannot set another allocator." ); |
443 | // Setting the priority high to make sure no other allocator gets used instead of this. |
444 | prev_allocator_ptr_ = c10::GetCPUAllocator(); |
445 | c10::SetCPUAllocator(c10::GetDefaultMobileCPUAllocator(), /*priority*/ 100); |
446 | } |
447 | |
448 | void Context::unsetDefaultMobileCPUAllocator() { |
449 | TORCH_CHECK(prev_allocator_ptr_ != nullptr, |
450 | "setDefaultMobileCPUAllocator must have been called " |
451 | "before unsetDefaultMobileCPUAllocator." ); |
452 | // Setting the priority high to make sure no other allocator gets used instead of this. |
453 | c10::SetCPUAllocator(prev_allocator_ptr_ , /*priority*/ 100); |
454 | prev_allocator_ptr_ = nullptr; |
455 | } |
456 | } // namespace at |
457 | |