1 | #pragma once |
2 | |
3 | // Engine implements backpropagation from output variables and their gradients |
4 | // to "root" variables (variables created by the user with requires_grad=True). |
5 | |
6 | #include <ATen/Tensor.h> |
7 | #include <ATen/ThreadLocalState.h> |
8 | #include <ATen/core/ivalue.h> |
9 | #include <torch/csrc/Export.h> |
10 | #include <torch/csrc/autograd/anomaly_mode.h> |
11 | #include <torch/csrc/autograd/function.h> |
12 | #include <torch/csrc/autograd/functions/basic_ops.h> |
13 | #include <torch/csrc/autograd/graph_task.h> |
14 | #include <torch/csrc/autograd/input_buffer.h> |
15 | #include <torch/csrc/autograd/saved_variable_hooks.h> |
16 | #include <torch/csrc/autograd/utils/warnings.h> |
17 | |
18 | #include <c10/util/CallOnce.h> |
19 | |
20 | #include <deque> |
21 | #include <exception> |
22 | #include <functional> |
23 | #include <memory> |
24 | #include <queue> |
25 | #include <thread> |
26 | #include <unordered_map> |
27 | #include <utility> |
28 | #include <vector> |
29 | |
30 | namespace torch { |
31 | namespace autograd { |
32 | struct ReadyQueue; |
33 | } |
34 | } // namespace torch |
35 | |
36 | namespace torch { |
37 | namespace autograd { |
38 | |
// Maximum reentrant backward depth before switching to a new thread
// This limit is based on the TSAN's deadlock detector, where it will
// fail if a program hold more than 65 locks in one thread at once.
// As we hold mutex in every of our custom C++ autograd Node, we would
// like to avoid TSAN complains on this when doing reentrant backwards
// For reference, see https://github.com/google/sanitizers/issues/950
static constexpr int MAX_DEPTH = 60;

// Sets the active device for the calling thread (defined in engine.cpp).
// NOTE(review): presumably used by worker threads to bind to the device
// whose ready queue they service — confirm against the implementation.
void set_device(int device);
// Validates `grads` against the functions referenced by `edges`. `grads` is
// taken by non-const reference, so the implementation may adjust gradients in
// place. `format_error` wraps a raw error message with caller-specific
// context before it is reported.
void validate_outputs(
    const edge_list& edges,
    variable_list& grads,
    const std::function<std::string(const std::string&)>& format_error);
52 | |
// A unit of work for the autograd engine: one Node to run, together with the
// accumulated input gradients for it and (weakly) the GraphTask it belongs to.
struct NodeTask {
  // The owning GraphTask. Held weakly because the GraphTask is owned
  // elsewhere and may already be gone when a queued task is popped.
  std::weak_ptr<GraphTask> base_;
  // The autograd Node to execute once all of its input dependencies are done.
  std::shared_ptr<Node> fn_;
  // This buffer serves as an implicit "addition" node for all of the
  // gradients flowing here. Once all the dependencies are finished, we
  // use the contents of this buffer to run the function.
  InputBuffer inputs_;
  // When worker receives a task with isShutdownTask = true, it will immediately
  // exit. The engine sends a shutdown task to every queue upon its destruction.
  bool isShutdownTask_;

  // Reentrant depth of the owning GraphTask; consulted by the ready queue's
  // comparator so that more deeply nested reentrant backward calls are
  // popped first (see ReadyQueue::CompareNodeTaskTime).
  int getReentrantDepth() const;

  NodeTask(
      std::weak_ptr<GraphTask> base,
      std::shared_ptr<Node> fn,
      InputBuffer inputs,
      bool isShutdownTask = false)
      : base_(std::move(base)),
        fn_(std::move(fn)),
        inputs_(std::move(inputs)),
        isShutdownTask_(isShutdownTask) {}
};
76 | |
// RAII guard that sets and restores checkpoint_valid: the constructor
// records the previous state and applies the state derived from
// `graph_task`; the destructor restores the recorded state. (Definitions
// live in engine.cpp; see also Engine::is_checkpoint_valid().)
class CheckpointValidGuard {
 public:
  explicit CheckpointValidGuard(
      const std::shared_ptr<const GraphTask>& graph_task);
  ~CheckpointValidGuard();

 private:
  // Value of checkpoint_valid before this guard was constructed; restored
  // on destruction.
  bool prev_checkpoint_valid_state;
};
87 | |
88 | struct ReadyQueue { |
89 | private: |
90 | // Returns true when t2 should be (weakly) BEFORE t1 in the queue. |
91 | // Shutdown tasks are first and then empty NodeTask are next. |
92 | struct CompareNodeTaskTime { |
93 | bool operator()(NodeTask const& t1, NodeTask const& t2) { |
94 | // NOLINTNEXTLINE(bugprone-branch-clone) |
95 | if (t2.isShutdownTask_) { |
96 | return true; |
97 | } else if (!t1.fn_ || t1.isShutdownTask_) { |
98 | return false; |
99 | } else if (!t2.fn_) { |
100 | return true; |
101 | } else if (t1.getReentrantDepth() == t2.getReentrantDepth()) { |
102 | return t1.fn_->sequence_nr() < t2.fn_->sequence_nr(); |
103 | } else { |
104 | return t1.getReentrantDepth() < t2.getReentrantDepth(); |
105 | } |
106 | } |
107 | }; |
108 | |
109 | // To notify threads waiting on the ReadyQueue of available tasks on the heap_ |
110 | std::condition_variable not_empty_; |
111 | // To protect read and writes to heap_ |
112 | mutable std::mutex mutex_; |
113 | |
114 | std::priority_queue<NodeTask, std::vector<NodeTask>, CompareNodeTaskTime> |
115 | heap_; |
116 | |
117 | public: |
118 | // incrementOutstandingTasks indicates whether or not we should increment |
119 | // 'outstanding_tasks_' for the associated GraphTask. This should mostly |
120 | // always be true and is only set false in certain cases (see docs for |
121 | // DistEngine.execute_graph_task_until_ready_queue_empty) |
122 | void push(NodeTask item, bool incrementOutstandingTasks = true); |
123 | void pushShutdownTask(); |
124 | NodeTask pop(); |
125 | bool empty() const; |
126 | size_t size() const; |
127 | }; |
128 | |
// A single instance of this struct should be created through the whole process
// lifetime. The worker thread creation logic and Engine's destructor rely on
// this.
struct TORCH_API Engine {
  /// Returns a reference to a static `Engine` instance.
  static Engine& get_default_engine();

  // Returns the base C++ engine. NOTE(review): presumably bypasses any
  // engine stub installed via set_default_engine_stub (e.g. the Python
  // engine) — confirm in engine.cpp.
  static Engine& get_base_engine();

  // The engine is a process-wide singleton: copying and moving are disabled.
  Engine(const Engine&) = delete;
  Engine(Engine&&) = delete;
  virtual ~Engine();

  // Given a list of (Node, input number) pairs computes the value of the graph
  // by following next_edge references.
  // `inputs` are the gradients fed into `roots`; `keep_graph`/`create_graph`
  // and `accumulate_grad` configure the backward pass, and `outputs`
  // optionally restricts which edges gradients are computed for (semantics
  // are defined by the implementation in engine.cpp).
  virtual variable_list execute(
      const edge_list& roots,
      const variable_list& inputs,
      bool keep_graph,
      bool create_graph,
      bool accumulate_grad,
      const edge_list& outputs = {});

  // Given a pre-populated GraphTask and GraphRoot, computes the backward pass
  // for the graph.
  //
  // NB: This API should only be used by internal autograd specific
  // machinery and shouldn't be exposed to users in anyway.
  virtual c10::intrusive_ptr<at::ivalue::Future> execute_with_graph_task(
      const std::shared_ptr<GraphTask>& graph_task,
      std::shared_ptr<Node> graph_root,
      InputBuffer&& input_buffer);

  // Factory hook for anomaly-detection metadata; subclasses (e.g. a Python
  // engine) can override to attach richer diagnostics.
  virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
    return std::make_unique<AnomalyMetadata>();
  }

  // Default saved-variable hooks. The base engine has none (nullptr);
  // subclasses may override to install their own.
  virtual std::unique_ptr<SavedVariableHooks> get_default_saved_variable_hooks() {
    return nullptr;
  }

  // We pass cpu_ready_queue to evaluate_function, so that it knows
  // the correct ready queue to push to after a NodeTask is ready
  void evaluate_function(
      std::shared_ptr<GraphTask>& graph_task,
      Node* func,
      InputBuffer& inputs,
      const std::shared_ptr<ReadyQueue>& cpu_ready_queue);

  // Ensures the device worker threads are started (guarded by
  // start_device_threads_flag_ below, so this is safe to call repeatedly).
  void initialize_device_threads_pool();
  // Hook invoked when executing `fn` for `graph_task` throws `e`;
  // subclasses may override the error-propagation behavior.
  virtual void thread_on_exception(
      std::shared_ptr<GraphTask> graph_task,
      const std::shared_ptr<Node>& fn,
      std::exception& e);

  // Queues `callback` into final_callbacks_. NOTE(review): presumably run
  // when the current backward pass finishes — confirm in engine.cpp.
  void queue_callback(std::function<void()> callback);

  // Whether checkpointing is currently valid (see CheckpointValidGuard).
  bool is_checkpoint_valid();

  // Should be called after fork to notify that worker threads are gone
  void release_workers();

  // Must be called by subclass before destructing to avoid a data-race-on-vptr.
  void stop();

  // Initializes a device thread for the autograd engine.
  // `should_increment` — NOTE(review): presumably controls whether
  // non_reentrant_device_thread_count_ is bumped for this thread; confirm.
  virtual void thread_init(
      int device,
      const std::shared_ptr<ReadyQueue>& ready_queue,
      bool should_increment = true);

 protected:
  Engine();
  // Traverses the graph from `root` and records per-Node dependency counts
  // into `task` (exact use of `min_topo_nr` is defined in engine.cpp).
  void compute_dependencies(Node* root, GraphTask& task, uint64_t min_topo_nr);

  // initialize the thread local ready queue with the ready queue that is
  // created elsewhere (i.e. thread_init, Engine::execute, etc), or create a new
  // ready queue if ready_queue is not provided.
  void init_local_ready_queue(
      std::shared_ptr<ReadyQueue> ready_queue = nullptr);

  // Selects the ready queue for `device`, falling back to `cpu_ready_queue`
  // for CPU work (device queues live in device_ready_queues_).
  std::shared_ptr<ReadyQueue> ready_queue(
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      at::Device device);
  // Same as ready_queue(), keyed by a raw device index instead of at::Device.
  std::shared_ptr<ReadyQueue> ready_queue_by_index(
      std::shared_ptr<ReadyQueue> cpu_ready_queue,
      int device_index);
  // start device threads (CUDA, XLA, etc.) in Engine,
  // note that it does NOT start CPU thread.
  void start_device_threads();
  void increment_non_reentrant_thread_count();
  void decrement_non_reentrant_thread_count();
  // Main loop of a worker thread: pops NodeTasks and executes them until the
  // given GraphTask (or, for device threads, the engine) is done.
  virtual void thread_main(const std::shared_ptr<GraphTask>& task);
  // Initialization for threads spawned to handle reentrant backward calls.
  void reentrant_thread_init();
  // Hands `graph_task` to the reentrant thread pool (see ThreadPoolShared).
  void add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task);

  // Ensures device_ready_queues_ are initialized only once
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  c10::once_flag start_device_threads_flag_;
  // Safe to read device_ready_queues_ without synchronization after
  // initialization
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::shared_ptr<ReadyQueue>> device_ready_queues_;

  // Callbacks registered via queue_callback().
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<std::function<void()>> final_callbacks_;
  // To protect reads and writes to final_callbacks_
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::mutex post_callbacks_lock_;

  // How many nested reentrant calls are allowed until a new thread is used
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  int max_recursion_depth_;

  struct ThreadPoolShared {
    // Data structures used by the threads for executing reentrant backwards
    // tasks. See Note [Reentrant backwards]
    // Number of available threads for processing new GraphTasks.
    unsigned int num_workers_{0};
    // The threads will wait on work_ to be notified of GraphTasks
    std::condition_variable work_;
    // To protect reads and writes to graphtask_queue_ and num_workers_
    // and for synchronizing creating new threads when needed
    std::mutex mutex_;
    // Workers will process the GraphTasks added to this queue. A GraphTask is
    // allocated inside Engine::execute and lives for the duration of execute
    std::queue<std::weak_ptr<GraphTask>> graphtasks_queue_;

    ThreadPoolShared() = default;
  };

  // Temporary workaround until shutting down threads is done
  // We need shared ownership of all these objects because the threads are
  // leaked when Engine shuts down, so there may be threads waiting on work_ for
  // the graphtasks_queue_ to be nonempty.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::shared_ptr<ThreadPoolShared> thread_pool_shared_;

 private:
  // Number of non-reentrant threads
  std::atomic<uint32_t> non_reentrant_device_thread_count_;
  // Destructor will wait for non-reentrant threads to finish
  std::condition_variable non_reentrant_device_thread_condvar_;
  std::mutex non_reentrant_device_thread_mutex_;
  // stop() must be called before the destruction path goes down to the base
  // class, in order to avoid a data-race-on-vptr. Use this boolean to guard
  // whether stop() has already been called, so we can call this in every
  // destructor of the class hierarchy.
  bool stopped_{false};
};
279 | |
// allow python_engine to override the default engine when it loads
// EngineStub is a factory returning the engine singleton to use;
// NOTE(review): presumably consulted by Engine::get_default_engine() —
// confirm in engine.cpp.
using EngineStub = Engine& (*)();
TORCH_API void set_default_engine_stub(EngineStub stub);
283 | |
284 | } // namespace autograd |
285 | } // namespace torch |
286 | |