1#pragma once
2
3#include <ATen/Parallel.h>
4#include <ATen/record_function.h>
5#include <torch/csrc/api/include/torch/types.h>
6#include <torch/csrc/autograd/grad_mode.h>
7#include <torch/csrc/autograd/profiler.h>
8#include <cstdint>
9
10namespace torch {
11
12/// A RAII, thread-local guard that disabled gradient calculation.
13///
14/// Disabling gradient calculation is useful for inference, when you are sure
15/// that you will not call `at::Tensor::backward`. It will reduce memory
16/// consumption for computations that would otherwise have `requires_grad() ==
17/// true`.
18///
19/// In this mode, the result of every computation will have
20/// `requires_grad() == false`, even when the inputs have `requires_grad() ==
21/// true`.
22///
23/// This context manager is thread-local; it will not affect computation
24/// in other threads.
25///
26/// Example:
27/// @code
28/// auto x = torch::tensor({1.}, torch::requires_grad());
29/// {
30/// torch::NoGradGuard no_grad;
31/// auto y = x * 2;
32/// std::cout << y.requires_grad() << std::endl; // prints `false`
33/// }
34/// {
35/// auto doubler = [](torch::Tensor x) {
36/// torch::NoGradGuard no_grad;
37/// return x * 2;
38/// };
39/// auto z = doubler(x);
40/// std::cout << z.requires_grad() << std::endl; // prints `false`
41/// }
42/// @endcode
43using NoGradGuard = at::NoGradGuard;
44
45/// A RAII, thread-local guard that sets gradient calculation to on or off.
46///
47/// ``AutoGradMode`` will enable or disable grads based on its argument
48/// `enabled`.
49///
50/// This context manager is thread-local; it will not affect computation
51/// in other threads.
52///
53/// \param enabled: Flag whether to enable grad (``true``), or disable
54/// (``false``). This can be used to conditionally enable
55/// gradients.
56///
57/// Example:
58/// @code
59/// auto x = torch::tensor({1.}, torch::requires_grad());
60/// {
61/// torch::AutoGradMode enable_grad(true);
62/// auto y = x * 2;
63/// std::cout << y.requires_grad() << std::endl; // prints `true`
64/// }
65/// {
66/// torch::AutoGradMode enable_grad(false);
67/// auto y = x * 2;
68/// std::cout << y.requires_grad() << std::endl; // prints `false`
69/// }
70/// @endcode
71using AutoGradMode = at::AutoGradMode;
72
73/// Sets the global random seed for all newly created CPU and CUDA tensors.
74using at::manual_seed;
75
76// Called during new thread initialization
77using at::init_num_threads;
78
79// Returns the number of threads used in parallel region.
80using at::get_num_threads;
81
82// Sets the number of threads to be used in parallel region.
83using at::set_num_threads;
84
85// Returns the number of threads used for inter-op parallelism.
86using at::get_num_interop_threads;
87
88// Sets the number of threads to be used for inter-op parallelism.
89using at::set_num_interop_threads;
90
91// Returns true if both t1, t2 are undefined or both are defined and equal
92inline bool equal_if_defined(Tensor t1, Tensor t2) {
93 return (
94 (!t1.defined() && !t2.defined()) ||
95 (t1.defined() && t2.defined() && torch::equal(t1, t2)));
96}
97
98// RecordFunction API
99using at::addGlobalCallback;
100using at::addThreadLocalCallback;
101using at::CallbackHandle;
102using at::clearCallbacks;
103using at::clearGlobalCallbacks;
104using at::clearThreadLocalCallbacks;
105using at::DisableRecordFunctionGuard;
106using at::enableRecordFunction;
107using at::hasCallbacks;
108using at::hasGlobalCallbacks;
109using at::hasThreadLocalCallbacks;
110using at::isRecordFunctionEnabled;
111using at::RecordFunction;
112using at::RecordFunctionCallback;
113using at::RecordFunctionGuard;
114using at::removeCallback;
115
116} // namespace torch
117