1 | #pragma once |
2 | |
3 | #include <ATen/Parallel.h> |
4 | #include <ATen/record_function.h> |
5 | #include <torch/csrc/api/include/torch/types.h> |
6 | #include <torch/csrc/autograd/grad_mode.h> |
7 | #include <torch/csrc/autograd/profiler.h> |
8 | #include <cstdint> |
9 | |
10 | namespace torch { |
11 | |
12 | /// A RAII, thread-local guard that disabled gradient calculation. |
13 | /// |
14 | /// Disabling gradient calculation is useful for inference, when you are sure |
15 | /// that you will not call `at::Tensor::backward`. It will reduce memory |
16 | /// consumption for computations that would otherwise have `requires_grad() == |
17 | /// true`. |
18 | /// |
19 | /// In this mode, the result of every computation will have |
20 | /// `requires_grad() == false`, even when the inputs have `requires_grad() == |
21 | /// true`. |
22 | /// |
23 | /// This context manager is thread-local; it will not affect computation |
24 | /// in other threads. |
25 | /// |
26 | /// Example: |
27 | /// @code |
28 | /// auto x = torch::tensor({1.}, torch::requires_grad()); |
29 | /// { |
30 | /// torch::NoGradGuard no_grad; |
31 | /// auto y = x * 2; |
32 | /// std::cout << y.requires_grad() << std::endl; // prints `false` |
33 | /// } |
34 | /// { |
35 | /// auto doubler = [](torch::Tensor x) { |
36 | /// torch::NoGradGuard no_grad; |
37 | /// return x * 2; |
38 | /// }; |
39 | /// auto z = doubler(x); |
40 | /// std::cout << z.requires_grad() << std::endl; // prints `false` |
41 | /// } |
42 | /// @endcode |
43 | using NoGradGuard = at::NoGradGuard; |
44 | |
45 | /// A RAII, thread-local guard that sets gradient calculation to on or off. |
46 | /// |
47 | /// ``AutoGradMode`` will enable or disable grads based on its argument |
48 | /// `enabled`. |
49 | /// |
50 | /// This context manager is thread-local; it will not affect computation |
51 | /// in other threads. |
52 | /// |
53 | /// \param enabled: Flag whether to enable grad (``true``), or disable |
54 | /// (``false``). This can be used to conditionally enable |
55 | /// gradients. |
56 | /// |
57 | /// Example: |
58 | /// @code |
59 | /// auto x = torch::tensor({1.}, torch::requires_grad()); |
60 | /// { |
61 | /// torch::AutoGradMode enable_grad(true); |
62 | /// auto y = x * 2; |
63 | /// std::cout << y.requires_grad() << std::endl; // prints `true` |
64 | /// } |
65 | /// { |
66 | /// torch::AutoGradMode enable_grad(false); |
67 | /// auto y = x * 2; |
68 | /// std::cout << y.requires_grad() << std::endl; // prints `false` |
69 | /// } |
70 | /// @endcode |
71 | using AutoGradMode = at::AutoGradMode; |
72 | |
73 | /// Sets the global random seed for all newly created CPU and CUDA tensors. |
74 | using at::manual_seed; |
75 | |
76 | // Called during new thread initialization |
77 | using at::init_num_threads; |
78 | |
79 | // Returns the number of threads used in parallel region. |
80 | using at::get_num_threads; |
81 | |
82 | // Sets the number of threads to be used in parallel region. |
83 | using at::set_num_threads; |
84 | |
85 | // Returns the number of threads used for inter-op parallelism. |
86 | using at::get_num_interop_threads; |
87 | |
88 | // Sets the number of threads to be used for inter-op parallelism. |
89 | using at::set_num_interop_threads; |
90 | |
91 | // Returns true if both t1, t2 are undefined or both are defined and equal |
92 | inline bool equal_if_defined(Tensor t1, Tensor t2) { |
93 | return ( |
94 | (!t1.defined() && !t2.defined()) || |
95 | (t1.defined() && t2.defined() && torch::equal(t1, t2))); |
96 | } |
97 | |
98 | // RecordFunction API |
99 | using at::addGlobalCallback; |
100 | using at::addThreadLocalCallback; |
101 | using at::CallbackHandle; |
102 | using at::clearCallbacks; |
103 | using at::clearGlobalCallbacks; |
104 | using at::clearThreadLocalCallbacks; |
105 | using at::DisableRecordFunctionGuard; |
106 | using at::enableRecordFunction; |
107 | using at::hasCallbacks; |
108 | using at::hasGlobalCallbacks; |
109 | using at::hasThreadLocalCallbacks; |
110 | using at::isRecordFunctionEnabled; |
111 | using at::RecordFunction; |
112 | using at::RecordFunctionCallback; |
113 | using at::RecordFunctionGuard; |
114 | using at::removeCallback; |
115 | |
116 | } // namespace torch |
117 | |