#pragma once

#include <c10/util/ArrayRef.h>
#include <torch/csrc/lazy/backend/lowering_context.h>
#include <torch/csrc/lazy/core/cache.h>
#include <torch/csrc/lazy/core/ir_util.h>
#include <torch/csrc/lazy/core/multi_wait.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/csrc/lazy/core/util.h>

namespace torch {
namespace lazy {

class TORCH_API LazyGraphExecutor {
 public:
  struct DeviceDataInfo : public BackendData::Info {
    DeviceDataInfo(int64_t tensor_id, bool read_only)
        : tensor_id(tensor_id), read_only(read_only) {}

    int64_t tensor_id = 0;
    bool read_only = false;
  };

  // Register a lazy graph executor instance that can be retrieved using Get()
  static void Register(LazyGraphExecutor*);
  static LazyGraphExecutor* Get();
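
  // A minimal usage sketch, not part of this API's contract: a backend
  // registers one executor instance during initialization and client code
  // retrieves it through the singleton accessor. `MyGraphExecutor` below is
  // a hypothetical backend-specific subclass.
  //
  //   static MyGraphExecutor executor;
  //   LazyGraphExecutor::Register(&executor);
  //   ...
  //   LazyGraphExecutor* exec = LazyGraphExecutor::Get();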

  virtual ~LazyGraphExecutor() = default;

  // Override these methods to perform custom tensor registration and
  // unregistration. Note: it is vital that the parent implementations are
  // also called, in order for the tensors to show up in the live tensor list.
  virtual void RegisterTensor(std::shared_ptr<LazyTensor::Data> data);
  virtual void UnregisterTensor(LazyTensor::Data* data);

  // Seed for random generator.
  // Override to supply your own DeviceContextArena.
  virtual Value GetRngSeed(const BackendDevice& device);
  virtual uint64_t GetRunningSeed(const BackendDevice& device);
  virtual void SetRngSeed(const BackendDevice& device, uint64_t seed);
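
  // Usage sketch (not normative): `device` is a placeholder BackendDevice
  // obtained from the registered backend.
  //
  //   LazyGraphExecutor::Get()->SetRngSeed(device, /*seed=*/42);
  //   Value seed_ir = LazyGraphExecutor::Get()->GetRngSeed(device);
  //   uint64_t running = LazyGraphExecutor::Get()->GetRunningSeed(device);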

  void DeviceBarrier(const BackendDevice& device);

  BackendDataPtr GetDeviceData(
      const at::Tensor& tensor,
      const BackendDevice& device);

  BackendDataPtr GetDeviceData(
      const at::Scalar& value,
      at::ScalarType scalar_type,
      const BackendDevice& device);
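
  // Sketch of the scalar overload (placeholder `device`); repeated calls with
  // the same value and type may be served from the per-device data cache
  // (see DataCacheArena below).
  //
  //   BackendDataPtr one = LazyGraphExecutor::Get()->GetDeviceData(
  //       at::Scalar(1.0), at::kFloat, device);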

  // Retrieves the set of lazy tensors which are currently live in the system,
  // for the given device. If device is nullptr, the live tensors for all
  // devices will be returned. Returned tensors are sorted by device as primary
  // key, and by unique ID as secondary key.
  std::vector<LazyTensorPtr> GetLiveTensors(const BackendDevice* device);
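
  // For example, passing nullptr collects the live tensors of every device:
  //
  //   std::vector<LazyTensorPtr> live =
  //       LazyGraphExecutor::Get()->GetLiveTensors(/*device=*/nullptr);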

  // Makes sure that any outstanding IR operations accumulated over live
  // tensors get turned into device data. If wait is true, the sync operation
  // will be run synchronously. The devices argument, if not empty, tells
  // which devices should be participating in the replicated computation.
  virtual void SyncLiveTensorsGraph(
      const BackendDevice* device,
      c10::ArrayRef<std::string> devices,
      bool wait);
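
  // Typical step-boundary flush (sketch; placeholder `device`), blocking
  // until the resulting computation has completed:
  //
  //   LazyGraphExecutor::Get()->SyncLiveTensorsGraph(
  //       &device, /*devices=*/{}, /*wait=*/true);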

  // Applies all the pending IR operations queued over the input tensors. All
  // the tensors must be on the same device. If wait is true, the sync
  // operation will be run synchronously. The devices argument, if not empty,
  // tells which devices should be participating in the replicated computation.
  void SyncTensorsGraph(
      std::vector<LazyTensorPtr>* tensors,
      c10::ArrayRef<std::string> devices,
      bool wait,
      bool sync_ltc_data);
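
  // Sketch (placeholder `tensors` vector): materialize a specific set of
  // tensors without blocking on the scheduled execution.
  //
  //   LazyGraphExecutor::Get()->SyncTensorsGraph(
  //       &tensors, /*devices=*/{}, /*wait=*/false, /*sync_ltc_data=*/true);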

  // Marks an execution step, which allows the tensor framework to understand
  // the computation boundaries.
  // Override to supply your own DeviceContextArena.
  virtual void MarkStep(const BackendDevice& device);

  // Waits for all the outstanding operations on all the supplied devices.
  // If devices is empty, the wait will happen for all local devices.
  void WaitDeviceOps(c10::ArrayRef<BackendDevice> devices);
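
  // Sketch of an end-of-iteration sequence (placeholder `device`): mark the
  // step boundary, then optionally block until the device has drained its
  // pending asynchronous work.
  //
  //   LazyGraphExecutor::Get()->MarkStep(device);
  //   LazyGraphExecutor::Get()->WaitDeviceOps({device});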

  // Retrieves the PyTorch CPU tensors behind the lazy tensors' IR operations.
  // All the tensors must be on the same device.
  std::vector<at::Tensor> GetTensors(std::vector<LazyTensorPtr>* tensors);
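
  // Sketch (placeholder `tensors`): fetch the materialized CPU values,
  // running any pending computation first if needed.
  //
  //   std::vector<at::Tensor> cpu_tensors =
  //       LazyGraphExecutor::Get()->GetTensors(&tensors);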

  size_t IncTrimCounter() const;

  // Dumps the backend specific text of the computation accumulated in the
  // graph which is attached to the tensors.
  std::string DumpBackendComputation(const std::vector<LazyTensorPtr>& tensors);

  Value GetDeviceDataIrValue(
      const at::Scalar& value,
      c10::ScalarType type,
      const BackendDevice& device);
  Value GetIrValueForScalar(
      const at::Scalar& value,
      c10::ScalarType type,
      const BackendDevice& device);
  Value GetIrValueForScalar(
      const at::Scalar& value,
      const BackendDevice& device);

  // TODO: even though this API is currently used **only** in codegen to
  // generate real scalar IR values vs. scalar tensors, we would like to use
  // it in other cases where `GetIrValueForXXXScalar` is used as well. In
  // order to do that, we need to untangle the cases where we don't need
  // `expand` and where we don't expect a scalar tensor.
  Value GetIrValueForScalarFromCodegen(
      const at::Scalar& value,
      const BackendDevice& device);
  Value GetIrValueForExpandedScalar(
      const at::Scalar& value,
      const Shape& shape,
      const BackendDevice& device);

  struct CachedComputation {
    explicit CachedComputation(ComputationPtr computation)
        : computation(std::move(computation)) {}

    ComputationPtr computation;
  };

  using ComputationCache = Cache<hash_t, CachedComputation, HashReducer>;

  ComputationCache* GetComputationCache();

  hash_t GetGraphHash(const std::vector<LazyTensorPtr>& tensors);
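
  // Sketch (placeholder `tensors`): the hash identifies the IR trace rooted
  // at the given tensors; identical traces hash to the same value, which is
  // the key used by the ComputationCache above.
  //
  //   hash_t graph_hash = LazyGraphExecutor::Get()->GetGraphHash(tensors);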

 protected:
  // TODO(alanwaketan): Revisit if all of them need to be accessible to
  // derived classes.

  struct SyncTensorsConfig {
    // Whether we want to force data on the target tensors (hence trimming
    // the IR graph above them).
    bool force_ltc_data = true;
    // Whether, when setting the data, the other properties of the tensor
    // state should be reset.
    bool sync_ltc_data = true;
  };
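
  // Sketch of how a derived executor might assemble a config: trim the IR
  // down to device data but leave the rest of the tensor state untouched.
  //
  //   SyncTensorsConfig config;
  //   config.force_ltc_data = true;
  //   config.sync_ltc_data = false;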

  struct SyncTensorCollection {
    SyncTensorCollection() : hash(0) {}

    SyncTensorsConfig config;
    std::vector<size_t> indices;
    hash_t hash;
    std::vector<ExceptionCleanup> unlocker;
    BackendDevice device;
  };

  struct PostOrderData {
    std::vector<const Node*> post_order;
    Util::EmissionMap emission_map;
    std::vector<BackendDataPtr> parameters_data;
    std::vector<size_t> parameter_sequence;
  };

  // Locking:
  // We perform two kinds of operations on tensors, synchronous and
  // asynchronous. ApplyPendingGraph() is synchronous, as we need the device
  // data result immediately. Before the synchronous operations can start,
  // they need to wait until the pending asynchronous operations have
  // completed. Synchronous operations do not hold device locks, since they
  // are strictly sequential, dictated by the PyTorch execution order.
  // SyncTensorsGraph() is asynchronous, and returns immediately after having
  // scheduled the asynchronous operation. While executing, the asynchronous
  // operations will hold locks on all the participating devices (in the most
  // common cases there will be only one device).
  // Since asynchronous operations capture device locks, only one asynchronous
  // operation can execute at the same time on a given device. Tensor
  // operations which send data to the device do not need to hold any device
  // locks while doing so. Only operations which _use_ device data
  // (computations, and transfers from the server) need to wait for
  // asynchronous operations to complete (barrier).

  class DeviceLocker {
   public:
    explicit DeviceLocker(BackendDevice device) : device_(std::move(device)) {}

    const BackendDevice& device() const {
      return device_;
    }

    void Lock();
    void Unlock(std::exception_ptr exptr);
    void Barrier();

   private:
    void CheckResetException();

    BackendDevice device_;
    std::mutex mutex_;
    std::condition_variable cv_;
    bool locked_ = false;
    std::exception_ptr exptr_;
  };
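
  // Simplified, single-threaded sketch of the lock/unlock protocol followed
  // around an asynchronous execution (placeholder `device`; in practice Lock()
  // and Unlock() typically happen on different threads, and Unlock() records
  // the exception, if any, raised by the asynchronous work):
  //
  //   DeviceLocker locker(device);
  //   locker.Lock();
  //   // ... schedule asynchronous work ...
  //   locker.Unlock(/*exptr=*/nullptr);
  //   locker.Barrier();  // block until the lock has been released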

  class DeviceLockerArena {
   public:
    static DeviceLockerArena* Get();

    std::shared_ptr<DeviceLocker> GetLocker(const BackendDevice& device);

    void DeviceBarrier(const BackendDevice& device);

    // Use a set to impose an order on the device locking sequence (ABBA
    // prevention).
    std::vector<ExceptionCleanup> LockDevices(
        const std::set<BackendDevice>& devices);

   private:
    ExceptionCleanup LockDevice(const BackendDevice& device);

    std::mutex mutex_;
    std::map<BackendDevice, std::shared_ptr<DeviceLocker>> lockers_;
  };
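
  // Sketch (placeholder devices): locking through a std::set gives a
  // deterministic acquisition order across threads, which is the ABBA
  // deadlock prevention mentioned above. The returned ExceptionCleanup
  // objects are expected to release the locks when they go out of scope.
  //
  //   std::set<BackendDevice> devices = {device_a, device_b};
  //   std::vector<ExceptionCleanup> unlocker =
  //       DeviceLockerArena::Get()->LockDevices(devices);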

  class DataCacheArena {
   public:
    static DataCacheArena* Get();

    BackendDataPtr GetDeviceData(
        const at::Tensor& tensor,
        const BackendDevice& device);

    BackendDataPtr GetDeviceData(
        const at::Scalar& value,
        at::ScalarType scalar_type,
        const BackendDevice& device);

   private:
    struct TensorHasher {
      size_t operator()(const at::Tensor& tensor) const;
    };
    struct TensorComparer {
      bool operator()(const at::Tensor& tensor1, const at::Tensor& tensor2)
          const;
    };

    explicit DataCacheArena(size_t max_cache_size);

    using DataCache =
        Cache<at::Tensor, BackendData, TensorHasher, TensorComparer>;

    DataCache* GetDataCache(const BackendDevice& device);

    size_t max_cache_size_ = 0;
    std::mutex mutex_;
    std::map<BackendDevice, std::unique_ptr<DataCache>> device_caches_;
  };

  // The DeviceContextArena holds per device live information and statistics,
  // among which the lazy tensors which are currently alive in the system. This
  // is used to create computation "barriers" in order to flush pending
  // operations and ensure the same computations are created during the
  // training loops.
  // TODO(alanwaketan): Add a registry such that we don't need to make all
  // related methods virtual.
  class DeviceContextArena {
   protected:
    struct DeviceContext {
      std::mutex lock;
      std::map<int64_t, std::weak_ptr<LazyTensor::Data>> tensors_data;
      uint64_t seed = 101;
      uint64_t running_seed = 101;
      Value seed_ir_value;
    };

   public:
    static DeviceContextArena* Get();
    virtual ~DeviceContextArena() = default;

    void RegisterTensor(std::shared_ptr<LazyTensor::Data> data);
    void UnregisterTensor(LazyTensor::Data* data);

    std::vector<LazyTensorPtr> GetLiveTensors(const BackendDevice* device);

    // Overriding it allows derived classes to use their own IRs for Value.
    virtual Value GetRngSeed(const BackendDevice& device);
    uint64_t GetRunningSeed(const BackendDevice& device);
    void SetRngSeed(const BackendDevice& device, uint64_t seed);

    void MarkStep(const BackendDevice& device);

    std::vector<BackendDevice> GetActiveDevices();

   protected:
    DeviceContext* GetDeviceContext(const BackendDevice& device);

    void ForAllDeviceContexts(
        const std::function<void(DeviceContext*)>& fn,
        const BackendDevice* device);

    // Overriding it allows derived classes to use their own conversions.
    virtual Value IrValueFromScalar(
        const at::Scalar& value,
        at::ScalarType scalar_type,
        const BackendDevice& device);

   private:
    std::vector<DeviceContext*> GetAllDeviceContexts();

    std::mutex lock_;
    std::map<BackendDevice, DeviceContext*> device_contexts_;
  };
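
  // Sketch of how the arena is typically consulted (placeholder `device`); a
  // backend that needs custom seed handling would subclass it and override
  // GetRngSeed() and IrValueFromScalar().
  //
  //   DeviceContextArena::Get()->MarkStep(device);
  //   std::vector<BackendDevice> active =
  //       DeviceContextArena::Get()->GetActiveDevices();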

  struct Async {
    Async(
        SyncTensorCollection* coll,
        std::vector<BackendDataPtr> parameters_data,
        std::vector<BackendDataPtr> tensors_data,
        ComputationCache::TypePtr cached_computation);
    virtual ~Async() = default;

    void Wait();

    MultiWait mwait;
    std::vector<size_t> indices;
    std::vector<ExceptionCleanup> unlocker;
    std::vector<BackendDataPtr> parameters_data;
    BackendDevice device;
    ComputationCache::TypePtr cached_computation;
    std::vector<BackendDataPtr> tensors_data;
  };

  void ResetTrimCounter() const;

  // Waits for this SyncTensorCollection's device barrier and acquires the
  // lock.
  virtual void TensorCollectionBarrier(SyncTensorCollection* coll);

  // Override this to insert your own profiler.
  virtual PostOrderData RunPostOrder(
      const std::vector<Value>& ir_values,
      SyncTensorCollection* coll);

 private:
  struct CompilationResult {
    BackendDevice device;
    size_t emitted_nodes = 0;
    ComputationPtr computation;
    std::vector<BackendDataPtr> parameters_data;
  };

  virtual bool ShouldSyncTensor(const LazyTensorPtr tensor) const;

  SyncTensorCollection CollectSyncTensors(
      const std::vector<LazyTensorPtr>& tensors,
      const SyncTensorsConfig& config);

  std::vector<Value> CollectRoots(
      const std::vector<LazyTensorPtr>& tensors,
      c10::ArrayRef<size_t> indices);

  std::vector<BackendDataPtr> SetTensorData(
      std::vector<LazyTensorPtr>* tensors,
      const SyncTensorsConfig& config,
      c10::ArrayRef<size_t> indices,
      const std::vector<torch::lazy::BackendDataPtr>& tensor_data_vec);

  void ExtractIRAndPrepareTensorData(
      std::vector<LazyTensorPtr>* tensors,
      const SyncTensorsConfig& config,
      c10::ArrayRef<size_t> indices,
      std::vector<Value>& ir_values,
      std::vector<BackendDataPtr>& tensor_data_vec);

  std::shared_ptr<Async> TryRunCachedSync(
      std::vector<LazyTensorPtr>* tensors,
      SyncTensorCollection* coll,
      PostOrderData* po_data,
      const std::vector<BackendDataPtr>& tensor_data_vec);

  CompilationResult Compile(
      const std::vector<LazyTensorPtr>& tensors,
      c10::ArrayRef<std::string> devices,
      const SyncTensorCollection& coll,
      PostOrderData* po_data,
      const std::vector<Value>& ir_values);

  ComputationCache::TypePtr LookupCachedCompile(const hash_t& hash);

  std::shared_ptr<Async> SyncTensorsGraphInternal(
      std::vector<LazyTensorPtr>* tensors,
      c10::ArrayRef<std::string> devices,
      const SyncTensorsConfig& config);

  // Schedules the execution of a sync tensors operation in the background.
  // The asynchronous operation will hold the device locks by capturing the
  // ones present within the coll structure.
  std::shared_ptr<Async> ScheduleSyncTensorsGraph(
      SyncTensorCollection* coll,
      std::vector<BackendDataPtr> parameters_data,
      std::vector<BackendDataPtr> tensors_data,
      ComputationCache::TypePtr cached_computation);

  std::shared_ptr<Async> ScheduleSyncTensorsGraph(
      std::vector<LazyTensorPtr>* tensors,
      SyncTensorCollection* coll,
      std::vector<BackendDataPtr> parameters_data,
      ComputationCache::TypePtr cached_computation,
      const std::vector<BackendDataPtr>& tensor_data_vec);

  std::vector<at::Tensor> GetTensorsFused(std::vector<LazyTensorPtr>* tensors);

  std::vector<at::Tensor> FetchTensors(
      std::vector<LazyTensorPtr>* tensors,
      c10::ArrayRef<BackendDataPtr> tensors_data,
      const std::vector<size_t>* indices);

  // Gathers the device data for all the input tensors, after an
  // asynchronous operation.
  std::vector<BackendDataPtr> GatherTensorsData(
      const std::vector<LazyTensorPtr>& tensors,
      c10::ArrayRef<size_t> indices,
      c10::ArrayRef<BackendDataPtr> tensors_data);
};

} // namespace lazy
} // namespace torch