1 | #pragma once |
2 | |
3 | #include <c10/util/ArrayRef.h> |
4 | #include <torch/csrc/lazy/backend/lowering_context.h> |
5 | #include <torch/csrc/lazy/core/cache.h> |
6 | #include <torch/csrc/lazy/core/ir_util.h> |
7 | #include <torch/csrc/lazy/core/multi_wait.h> |
8 | #include <torch/csrc/lazy/core/tensor.h> |
9 | #include <torch/csrc/lazy/core/util.h> |
10 | |
11 | namespace torch { |
12 | namespace lazy { |
13 | |
14 | class TORCH_API LazyGraphExecutor { |
15 | public: |
16 | struct DeviceDataInfo : public BackendData::Info { |
17 | DeviceDataInfo(int64_t tensor_id, bool read_only) |
18 | : tensor_id(tensor_id), read_only(read_only) {} |
19 | |
20 | int64_t tensor_id = 0; |
21 | bool read_only = false; |
22 | }; |
23 | |
// Register a lazy graph executor instance that can be retrieved using Get().
25 | static void Register(LazyGraphExecutor*); |
26 | static LazyGraphExecutor* Get(); |
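
// Example (sketch): a backend typically registers a single executor instance
// once at backend-initialization time; the rest of the stack then retrieves
// it through Get(). The registered object must outlive all uses.
// (MyBackendGraphExecutor is an illustrative subclass name.)
//
//   static MyBackendGraphExecutor executor;
//   LazyGraphExecutor::Register(&executor);
//   ...
//   LazyGraphExecutor* exec = LazyGraphExecutor::Get();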
27 | |
28 | virtual ~LazyGraphExecutor() = default; |
29 | |
// Override these methods to perform custom tensor registration and
// unregistration. Note: it is vital that the parent implementations are also
// called, in order for the tensors to show up in the live tensor list.
33 | virtual void RegisterTensor(std::shared_ptr<LazyTensor::Data> data); |
34 | virtual void UnregisterTensor(LazyTensor::Data* data); |
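
// Example (sketch) of a derived executor adding its own bookkeeping while
// still forwarding to the parent implementation so the tensor shows up in
// the live-tensor list (MyBackendGraphExecutor and OnTensorCreated are
// illustrative names):
//
//   void MyBackendGraphExecutor::RegisterTensor(
//       std::shared_ptr<LazyTensor::Data> data) {
//     OnTensorCreated(*data);  // hypothetical backend-specific hook
//     LazyGraphExecutor::RegisterTensor(std::move(data));
//   }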
35 | |
// Seed for the random number generator.
37 | // Override to supply your own DeviceContextArena. |
38 | virtual Value GetRngSeed(const BackendDevice& device); |
39 | virtual uint64_t GetRunningSeed(const BackendDevice& device); |
40 | virtual void SetRngSeed(const BackendDevice& device, uint64_t seed); |
41 | |
42 | void DeviceBarrier(const BackendDevice& device); |
43 | |
44 | BackendDataPtr GetDeviceData( |
45 | const at::Tensor& tensor, |
46 | const BackendDevice& device); |
47 | |
48 | BackendDataPtr GetDeviceData( |
49 | const at::Scalar& value, |
50 | at::ScalarType scalar_type, |
51 | const BackendDevice& device); |
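
// Example (sketch): uploading (or fetching from the per-device cache) the
// backend handle for a CPU tensor and for a scalar constant, assuming
// `device` is a BackendDevice already in scope:
//
//   at::Tensor cpu_tensor = at::randn({2, 2});
//   BackendDataPtr tensor_data =
//       LazyGraphExecutor::Get()->GetDeviceData(cpu_tensor, device);
//   BackendDataPtr scalar_data = LazyGraphExecutor::Get()->GetDeviceData(
//       at::Scalar(1.0), at::ScalarType::Float, device);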
52 | |
53 | // Retrieves the set of lazy tensors which are currently live in the system, |
54 | // for the given device. If device is nullptr, the live tensors for all |
55 | // devices will be returned. Returned tensors are sorted by device as primary |
56 | // key, and by unique ID as secondary key. |
57 | std::vector<LazyTensorPtr> GetLiveTensors(const BackendDevice* device); |
58 | |
// Makes sure that any outstanding IR operations accumulated over the live
// tensors get turned into device data. If wait is true, the sync operation
// will be run synchronously. The devices argument, if not empty, specifies
// the devices which should be participating in the replicated computation.
63 | virtual void SyncLiveTensorsGraph( |
64 | const BackendDevice* device, |
65 | c10::ArrayRef<std::string> devices, |
66 | bool wait); |
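
// Example (sketch): flushing the pending IR of every live tensor on one
// device and blocking until the sync has completed, assuming `device` is a
// BackendDevice in scope:
//
//   LazyGraphExecutor::Get()->SyncLiveTensorsGraph(
//       &device, /*devices=*/{}, /*wait=*/true);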
67 | |
68 | // Applies all the pending IR operations queued over the input tensors. All |
69 | // the tensors must be on the same device. If wait is true, the sync operation |
// will be run synchronously. The devices argument, if not empty, specifies
// the devices which should be participating in the replicated computation.
72 | void SyncTensorsGraph( |
73 | std::vector<LazyTensorPtr>* tensors, |
74 | c10::ArrayRef<std::string> devices, |
75 | bool wait, |
76 | bool sync_ltc_data); |
77 | |
78 | // Marks an execution step, which allows the tensor framework to understand |
79 | // the computation boundaries. |
80 | // Override to supply your own DeviceContextArena. |
81 | virtual void MarkStep(const BackendDevice& device); |
82 | |
83 | // Waits for all the outstanding operations on all the supplied devices. |
84 | // If devices is empty, the wait will happen for all local devices. |
85 | void WaitDeviceOps(c10::ArrayRef<BackendDevice> devices); |
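
// Example (sketch): a typical end-of-iteration sequence first flushes the
// live graph, marks the step boundary, and then (optionally) waits for the
// device to drain, assuming `device` is a BackendDevice in scope:
//
//   auto* exec = LazyGraphExecutor::Get();
//   exec->SyncLiveTensorsGraph(&device, /*devices=*/{}, /*wait=*/false);
//   exec->MarkStep(device);
//   exec->WaitDeviceOps({device});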
86 | |
// Retrieves the PyTorch CPU tensors behind the lazy tensors' IR operations.
88 | // All the tensors must be on the same device. |
89 | std::vector<at::Tensor> GetTensors(std::vector<LazyTensorPtr>* tensors); |
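
// Example (sketch): materializing the CPU values of a batch of lazy tensors;
// any IR still pending on them is executed first (CollectOutputs is an
// illustrative helper):
//
//   std::vector<LazyTensorPtr> lazy_tensors = CollectOutputs();
//   std::vector<at::Tensor> cpu_tensors =
//       LazyGraphExecutor::Get()->GetTensors(&lazy_tensors);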
90 | |
91 | size_t IncTrimCounter() const; |
92 | |
// Dumps the backend-specific text of the computation accumulated in the
// graph which is attached to the tensors.
95 | std::string DumpBackendComputation(const std::vector<LazyTensorPtr>& tensors); |
96 | |
97 | Value GetDeviceDataIrValue( |
98 | const at::Scalar& value, |
99 | c10::ScalarType type, |
100 | const BackendDevice& device); |
101 | Value GetIrValueForScalar( |
102 | const at::Scalar& value, |
103 | c10::ScalarType type, |
104 | const BackendDevice& device); |
105 | Value GetIrValueForScalar( |
106 | const at::Scalar& value, |
107 | const BackendDevice& device); |
108 | |
// TODO: even though this API is currently used **only** in codegen to
// generate real scalar IR values vs. scalar tensors, we would like to
// use it in the other cases where `GetIrValueForXXXScalar` is used as well.
// In order to do that, we need to untangle the cases where we don't need
// `expand` and where we don't expect a scalar tensor.
114 | Value GetIrValueForScalarFromCodegen( |
115 | const at::Scalar& value, |
116 | const BackendDevice& device); |
117 | Value GetIrValueForExpandedScalar( |
118 | const at::Scalar& value, |
119 | const Shape& shape, |
120 | const BackendDevice& device); |
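
// Example (sketch): obtaining IR values for constants while lowering an op,
// e.g. a scalar promoted to the output shape, assuming `device` and `shape`
// are in scope:
//
//   Value one = LazyGraphExecutor::Get()->GetIrValueForScalar(
//       at::Scalar(1.0), c10::ScalarType::Float, device);
//   Value alpha = LazyGraphExecutor::Get()->GetIrValueForExpandedScalar(
//       at::Scalar(0.5), shape, device);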
121 | |
122 | struct CachedComputation { |
123 | explicit CachedComputation(ComputationPtr computation) |
124 | : computation(std::move(computation)) {} |
125 | |
126 | ComputationPtr computation; |
127 | }; |
128 | |
129 | using ComputationCache = Cache<hash_t, CachedComputation, HashReducer>; |
130 | |
131 | ComputationCache* GetComputationCache(); |
132 | |
133 | hash_t GetGraphHash(const std::vector<LazyTensorPtr>& tensors); |
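
// Example (sketch): probing the cache for a previously compiled computation
// using the hash of the tensors' graph. This assumes the Get()/Add()
// interface of Cache from torch/csrc/lazy/core/cache.h:
//
//   hash_t hash = LazyGraphExecutor::Get()->GetGraphHash(tensors);
//   auto cached = LazyGraphExecutor::Get()->GetComputationCache()->Get(hash);
//   if (cached != nullptr) {
//     ComputationPtr computation = cached->computation;
//   }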
134 | |
135 | protected: |
136 | // TODO(alanwaketan): Revisit if all of them need to be accessible to |
137 | // derived classes. |
138 | |
139 | struct SyncTensorsConfig { |
140 | // Whether we want to force data on the target tensors (hence trimming |
141 | // the IR graph above them). |
142 | bool force_ltc_data = true; |
// Whether, when setting the data, the other properties of the tensor
// state should be reset.
145 | bool sync_ltc_data = true; |
146 | }; |
147 | |
148 | struct SyncTensorCollection { |
149 | SyncTensorCollection() : hash(0) {} |
150 | |
151 | SyncTensorsConfig config; |
152 | std::vector<size_t> indices; |
153 | hash_t hash; |
154 | std::vector<ExceptionCleanup> unlocker; |
155 | BackendDevice device; |
156 | }; |
157 | |
158 | struct PostOrderData { |
159 | std::vector<const Node*> post_order; |
160 | Util::EmissionMap emission_map; |
161 | std::vector<BackendDataPtr> parameters_data; |
162 | std::vector<size_t> parameter_sequence; |
163 | }; |
164 | |
165 | // Locking: |
// We perform two kinds of operations on tensors, synchronous and
// asynchronous. ApplyPendingGraph() is synchronous, as we need the device
// data result immediately. Before a synchronous operation can start, it
// needs to wait until the pending asynchronous operations have completed.
// Synchronous operations do not hold device locks, since they are strictly
// sequential, dictated by the PyTorch execution order. SyncTensorsGraph()
// is asynchronous, and returns immediately after having scheduled the
// asynchronous operation. While executing, the asynchronous operations will
// hold locks on all the participating devices (in the most common case
// there will be only one device).
// Since asynchronous operations capture device locks, only one asynchronous
// operation can execute at a time on a given device. Tensor operations which
// send data to the device do not need to hold any device locks while doing
// so. Only operations which _use_ device data (computations, and transfers
// from the server) need to wait for asynchronous operations to complete
// (barrier).
182 | |
183 | class DeviceLocker { |
184 | public: |
185 | explicit DeviceLocker(BackendDevice device) : device_(std::move(device)) {} |
186 | |
187 | const BackendDevice& device() const { |
188 | return device_; |
189 | } |
190 | |
191 | void Lock(); |
192 | void Unlock(std::exception_ptr exptr); |
193 | void Barrier(); |
194 | |
195 | private: |
196 | void CheckResetException(); |
197 | |
198 | BackendDevice device_; |
199 | std::mutex mutex_; |
200 | std::condition_variable cv_; |
201 | bool locked_ = false; |
202 | std::exception_ptr exptr_; |
203 | }; |
204 | |
205 | class DeviceLockerArena { |
206 | public: |
207 | static DeviceLockerArena* Get(); |
208 | |
209 | std::shared_ptr<DeviceLocker> GetLocker(const BackendDevice& device); |
210 | |
211 | void DeviceBarrier(const BackendDevice& device); |
212 | |
213 | // Use a set to impose an order on the device locking sequence (ABBA |
214 | // prevention). |
215 | std::vector<ExceptionCleanup> LockDevices( |
216 | const std::set<BackendDevice>& devices); |
217 | |
218 | private: |
219 | ExceptionCleanup LockDevice(const BackendDevice& device); |
220 | |
221 | std::mutex mutex_; |
222 | std::map<BackendDevice, std::shared_ptr<DeviceLocker>> lockers_; |
223 | }; |
224 | |
225 | class DataCacheArena { |
226 | public: |
227 | static DataCacheArena* Get(); |
228 | |
229 | BackendDataPtr GetDeviceData( |
230 | const at::Tensor& tensor, |
231 | const BackendDevice& device); |
232 | |
233 | BackendDataPtr GetDeviceData( |
234 | const at::Scalar& value, |
235 | at::ScalarType scalar_type, |
236 | const BackendDevice& device); |
237 | |
238 | private: |
239 | struct TensorHasher { |
240 | size_t operator()(const at::Tensor& tensor) const; |
241 | }; |
242 | struct TensorComparer { |
243 | bool operator()(const at::Tensor& tensor1, const at::Tensor& tensor2) |
244 | const; |
245 | }; |
246 | |
247 | explicit DataCacheArena(size_t max_cache_size); |
248 | |
249 | using DataCache = |
250 | Cache<at::Tensor, BackendData, TensorHasher, TensorComparer>; |
251 | |
252 | DataCache* GetDataCache(const BackendDevice& device); |
253 | |
254 | size_t max_cache_size_ = 0; |
255 | std::mutex mutex_; |
256 | std::map<BackendDevice, std::unique_ptr<DataCache>> device_caches_; |
257 | }; |
258 | |
// The DeviceContextArena holds per-device live information and statistics,
// among which are the lazy tensors currently alive in the system. This is
// used to create computation "barriers" in order to flush pending
// operations and ensure the same computations are created during the
// training loops.
264 | // TODO(alanwaketan): Add a registry such that we don't need to make all |
265 | // related methods virtual. |
266 | class DeviceContextArena { |
267 | protected: |
268 | struct DeviceContext { |
269 | std::mutex lock; |
270 | std::map<int64_t, std::weak_ptr<LazyTensor::Data>> tensors_data; |
271 | uint64_t seed = 101; |
272 | uint64_t running_seed = 101; |
273 | Value seed_ir_value; |
274 | }; |
275 | |
276 | public: |
277 | static DeviceContextArena* Get(); |
278 | virtual ~DeviceContextArena() = default; |
279 | |
280 | void RegisterTensor(std::shared_ptr<LazyTensor::Data> data); |
281 | void UnregisterTensor(LazyTensor::Data* data); |
282 | |
283 | std::vector<LazyTensorPtr> GetLiveTensors(const BackendDevice* device); |
284 | |
// Overriding it allows derived classes to use their own IRs for Value.
286 | virtual Value GetRngSeed(const BackendDevice& device); |
287 | uint64_t GetRunningSeed(const BackendDevice& device); |
288 | void SetRngSeed(const BackendDevice& device, uint64_t seed); |
289 | |
290 | void MarkStep(const BackendDevice& device); |
291 | |
292 | std::vector<BackendDevice> GetActiveDevices(); |
293 | |
294 | protected: |
295 | DeviceContext* GetDeviceContext(const BackendDevice& device); |
296 | |
297 | void ForAllDeviceContexts( |
298 | const std::function<void(DeviceContext*)>& fn, |
299 | const BackendDevice* device); |
300 | |
// Overriding it allows derived classes to use their own conversions.
302 | virtual Value IrValueFromScalar( |
303 | const at::Scalar& value, |
304 | at::ScalarType scalar_type, |
305 | const BackendDevice& device); |
306 | |
307 | private: |
308 | std::vector<DeviceContext*> GetAllDeviceContexts(); |
309 | |
310 | std::mutex lock_; |
311 | std::map<BackendDevice, DeviceContext*> device_contexts_; |
312 | }; |
313 | |
314 | struct Async { |
315 | Async( |
316 | SyncTensorCollection* coll, |
317 | std::vector<BackendDataPtr> parameters_data, |
318 | std::vector<BackendDataPtr> tensors_data, |
319 | ComputationCache::TypePtr cached_computation); |
320 | virtual ~Async() = default; |
321 | |
322 | void Wait(); |
323 | |
324 | MultiWait mwait; |
325 | std::vector<size_t> indices; |
326 | std::vector<ExceptionCleanup> unlocker; |
327 | std::vector<BackendDataPtr> parameters_data; |
328 | BackendDevice device; |
329 | ComputationCache::TypePtr cached_computation; |
330 | std::vector<BackendDataPtr> tensors_data; |
331 | }; |
332 | |
333 | void ResetTrimCounter() const; |
334 | |
// Waits for this SyncTensorCollection's device barrier and acquires the lock.
336 | virtual void TensorCollectionBarrier(SyncTensorCollection* coll); |
337 | |
// Override this to insert your own profiler.
339 | virtual PostOrderData RunPostOrder( |
340 | const std::vector<Value>& ir_values, |
341 | SyncTensorCollection* coll); |
342 | |
343 | private: |
344 | struct CompilationResult { |
345 | BackendDevice device; |
346 | size_t emitted_nodes = 0; |
347 | ComputationPtr computation; |
348 | std::vector<BackendDataPtr> parameters_data; |
349 | }; |
350 | |
351 | virtual bool ShouldSyncTensor(const LazyTensorPtr tensor) const; |
352 | |
353 | SyncTensorCollection CollectSyncTensors( |
354 | const std::vector<LazyTensorPtr>& tensors, |
355 | const SyncTensorsConfig& config); |
356 | |
357 | std::vector<Value> CollectRoots( |
358 | const std::vector<LazyTensorPtr>& tensors, |
359 | c10::ArrayRef<size_t> indices); |
360 | |
361 | std::vector<BackendDataPtr> SetTensorData( |
362 | std::vector<LazyTensorPtr>* tensors, |
363 | const SyncTensorsConfig& config, |
364 | c10::ArrayRef<size_t> indices, |
365 | const std::vector<torch::lazy::BackendDataPtr>& tensor_data_vec); |
366 | |
367 | void ExtractIRAndPrepareTensorData( |
368 | std::vector<LazyTensorPtr>* tensors, |
369 | const SyncTensorsConfig& config, |
370 | c10::ArrayRef<size_t> indices, |
371 | std::vector<Value>& ir_values, |
372 | std::vector<BackendDataPtr>& tensor_data_vec); |
373 | |
374 | std::shared_ptr<Async> TryRunCachedSync( |
375 | std::vector<LazyTensorPtr>* tensors, |
376 | SyncTensorCollection* coll, |
377 | PostOrderData* po_data, |
378 | const std::vector<BackendDataPtr>& tensor_data_vec); |
379 | |
380 | CompilationResult Compile( |
381 | const std::vector<LazyTensorPtr>& tensors, |
382 | c10::ArrayRef<std::string> devices, |
383 | const SyncTensorCollection& coll, |
384 | PostOrderData* po_data, |
385 | const std::vector<Value>& ir_values); |
386 | |
387 | ComputationCache::TypePtr LookupCachedCompile(const hash_t& hash); |
388 | |
389 | std::shared_ptr<Async> SyncTensorsGraphInternal( |
390 | std::vector<LazyTensorPtr>* tensors, |
391 | c10::ArrayRef<std::string> devices, |
392 | const SyncTensorsConfig& config); |
393 | |
// Schedules the execution of a sync tensors operation in the background. The
395 | // asynchronous operation will hold the device locks by capturing the ones |
396 | // present within the coll structure. |
397 | std::shared_ptr<Async> ScheduleSyncTensorsGraph( |
398 | SyncTensorCollection* coll, |
399 | std::vector<BackendDataPtr> parameters_data, |
400 | std::vector<BackendDataPtr> tensors_data, |
401 | ComputationCache::TypePtr cached_computation); |
402 | |
403 | std::shared_ptr<Async> ScheduleSyncTensorsGraph( |
404 | std::vector<LazyTensorPtr>* tensors, |
405 | SyncTensorCollection* coll, |
406 | std::vector<BackendDataPtr> parameters_data, |
407 | ComputationCache::TypePtr cached_computation, |
408 | const std::vector<BackendDataPtr>& tensor_data_vec); |
409 | |
410 | std::vector<at::Tensor> GetTensorsFused(std::vector<LazyTensorPtr>* tensors); |
411 | |
412 | std::vector<at::Tensor> FetchTensors( |
413 | std::vector<LazyTensorPtr>* tensors, |
414 | c10::ArrayRef<BackendDataPtr> tensors_data, |
415 | const std::vector<size_t>* indices); |
416 | |
417 | // Gathers the device data for all the input tensors, after an |
418 | // asynchronous operation. |
419 | std::vector<BackendDataPtr> GatherTensorsData( |
420 | const std::vector<LazyTensorPtr>& tensors, |
421 | c10::ArrayRef<size_t> indices, |
422 | c10::ArrayRef<BackendDataPtr> tensors_data); |
423 | }; |
424 | |
425 | } // namespace lazy |
426 | } // namespace torch |
427 | |