#include <torch/csrc/lazy/core/lazy_graph_executor.h>

#include <ATen/ScalarOps.h>
#include <c10/util/Logging.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir_dump_util.h>
#include <torch/csrc/lazy/core/ir_util.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/core/unique.h>

#include <torch/csrc/lazy/core/debug_util.h>
#include <torch/csrc/lazy/core/ir_builder.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <torch/csrc/lazy/core/ops/arithmetic_ir_ops.h>
#include <torch/csrc/lazy/core/thread_pool.h>

namespace torch {
namespace lazy {
namespace {

struct TlsData {
  void Reset() {
    trim_counter = 0;
  }

  size_t trim_counter = 0;
};

thread_local TlsData g_tls_data;

bool TensorCompare(const at::Tensor& t1, const at::Tensor& t2) {
  if (t1.scalar_type() != t2.scalar_type() || t1.sizes() != t2.sizes()) {
    return false;
  }
  // PyTorch currently has an issue comparing tensors which contain NaN values:
  // the comparison is not deterministic. So we do a memory compare here until
  // the PyTorch equal() API is fixed.
  at::Tensor contiguous_t1 = t1.contiguous();
  at::Tensor contiguous_t2 = t2.contiguous();
  return std::memcmp(
             contiguous_t1.data_ptr(),
             contiguous_t2.data_ptr(),
             contiguous_t1.numel() * contiguous_t1.itemsize()) == 0;
}

// Return true if any tensor in the list has an underlying IR (leaf or
// operation).
bool TensorsHaveIR(const std::vector<LazyTensorPtr>& tensors) {
  for (const auto& tensor : tensors) {
    if (tensor->CurrentDataHandle() || tensor->CurrentIrValue()) {
      return true;
    }
  }
  return false;
}

std::atomic<LazyGraphExecutor*> lazy_graph_executor_registry;
} // namespace

auto LazyGraphExecutor::DeviceContextArena::Get()
    -> LazyGraphExecutor::DeviceContextArena* {
  static DeviceContextArena* arena = new DeviceContextArena();
  return arena;
}

void LazyGraphExecutor::DeviceContextArena::RegisterTensor(
    std::shared_ptr<LazyTensor::Data> data) {
  DeviceContext* devctx = GetDeviceContext(data->device);
  std::lock_guard<std::mutex> lock(devctx->lock);
  devctx->tensors_data.emplace(data->unique_id, data);
}

void LazyGraphExecutor::DeviceContextArena::UnregisterTensor(
    LazyTensor::Data* data) {
  DeviceContext* devctx = GetDeviceContext(data->device);
  std::lock_guard<std::mutex> lock(devctx->lock);
  devctx->tensors_data.erase(data->unique_id);
}

std::vector<LazyTensorPtr> LazyGraphExecutor::DeviceContextArena::
    GetLiveTensors(const BackendDevice* device) {
  std::vector<LazyTensorPtr> tensors;
  auto fn = [&](DeviceContext* devctx) {
    std::lock_guard<std::mutex> lock(devctx->lock);
    for (auto& uid_wptr : devctx->tensors_data) {
      std::shared_ptr<LazyTensor::Data> data = uid_wptr.second.lock();
      if (data != nullptr) {
        tensors.push_back(LazyTensor::Create(std::move(data)));
      }
    }
  };
  ForAllDeviceContexts(fn, device);
  return tensors;
}

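// RNG seed handling: the per-device seed is advanced with a linear
// congruential step (seed' = kSeedAdd + kSeedMul * seed). It is kept both as
// a plain scalar (running_seed, returned without executing any graph) and as
// an IR value composed from the root seed with the same affine transform, so
// each step reuses a single device parameter instead of uploading a fresh
// seed tensor.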
Value LazyGraphExecutor::DeviceContextArena::GetRngSeed(
    const BackendDevice& device) {
  static const at::ScalarType kSeedType = at::ScalarType::Long;
  static const uint64_t kSeedMul = 214013;
  static const uint64_t kSeedAdd = 2531011;
  DeviceContext* devctx = GetDeviceContext(device);
  std::lock_guard<std::mutex> lock(devctx->lock);
  if (!devctx->seed_ir_value) {
    devctx->seed_ir_value =
        IrValueFromScalar(MakeIntScalar(devctx->seed), kSeedType, device);
  }
  // Keep the running seed as scalar as well, so we can return it directly
  // without executing graphs.
  devctx->running_seed = kSeedAdd + kSeedMul * devctx->running_seed;
  // Compose new seeds from the root seed, to avoid creating too many
  // computation parameters which might overflow the device capacity.
  Value k = MakeScalar(MakeIntScalar(kSeedMul), kSeedType);
  Value b = MakeScalar(MakeIntScalar(kSeedAdd), kSeedType);
  devctx->seed_ir_value = b + k * devctx->seed_ir_value;
  return devctx->seed_ir_value;
}

uint64_t LazyGraphExecutor::DeviceContextArena::GetRunningSeed(
    const BackendDevice& device) {
  DeviceContext* devctx = GetDeviceContext(device);
  std::lock_guard<std::mutex> lock(devctx->lock);
  return devctx->running_seed;
}

void LazyGraphExecutor::DeviceContextArena::SetRngSeed(
    const BackendDevice& device,
    uint64_t seed) {
  DeviceContext* devctx = GetDeviceContext(device);
  std::lock_guard<std::mutex> lock(devctx->lock);
  devctx->seed = seed;
  devctx->running_seed = devctx->seed;
  devctx->seed_ir_value = Value();
}

void LazyGraphExecutor::DeviceContextArena::MarkStep(
    const BackendDevice& device) {
  DeviceContext* devctx = GetDeviceContext(device);
  std::lock_guard<std::mutex> lock(devctx->lock);
  devctx->seed = 1012031 + devctx->seed * 7012063;
  devctx->running_seed = devctx->seed;
  devctx->seed_ir_value = Value();
}

std::vector<BackendDevice> LazyGraphExecutor::DeviceContextArena::
    GetActiveDevices() {
  std::vector<BackendDevice> active_devices;
  std::lock_guard<std::mutex> lock(lock_);
  active_devices.reserve(device_contexts_.size());
  for (auto& device_contexts : device_contexts_) {
    active_devices.push_back(device_contexts.first);
  }
  return active_devices;
}

auto LazyGraphExecutor::DeviceContextArena::GetAllDeviceContexts()
    -> std::vector<DeviceContext*> {
  std::vector<DeviceContext*> all_device_contexts;
  std::lock_guard<std::mutex> lock(lock_);
  all_device_contexts.reserve(device_contexts_.size());
  for (auto& device_contexts : device_contexts_) {
    all_device_contexts.push_back(device_contexts.second);
  }
  return all_device_contexts;
}

void LazyGraphExecutor::DeviceContextArena::ForAllDeviceContexts(
    const std::function<void(DeviceContext*)>& fn,
    const BackendDevice* device) {
  if (device == nullptr) {
    for (auto devctx : GetAllDeviceContexts()) {
      fn(devctx);
    }
  } else {
    fn(GetDeviceContext(*device));
  }
}

auto LazyGraphExecutor::DeviceContextArena::GetDeviceContext(
    const BackendDevice& device) -> DeviceContext* {
  std::lock_guard<std::mutex> lock(lock_);
  auto it = device_contexts_.find(device);
  if (it == device_contexts_.end()) {
    it = device_contexts_.emplace(device, new DeviceContext()).first;
  }
  return it->second;
}

Value LazyGraphExecutor::DeviceContextArena::IrValueFromScalar(
    const at::Scalar& value,
    at::ScalarType scalar_type,
    const BackendDevice& device) {
  at::Tensor tensor = at::scalar_tensor(value, at::TensorOptions(scalar_type));
  BackendDataPtr device_data = TensorToDataHandle(tensor, device);
  return MakeDeviceData(std::move(device_data));
}

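// DeviceLocker serializes access to a device: Lock() blocks until the device
// is free and then claims it, Unlock() releases it and records any exception
// raised by the asynchronous work, and Barrier() waits for the device to
// become free without claiming it. A pending exception is rethrown to the
// next caller via CheckResetException().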
void LazyGraphExecutor::DeviceLocker::Lock() {
  std::unique_lock<std::mutex> lock(mutex_);
  cv_.wait(lock, [this] { return !locked_; });
  CheckResetException();
  locked_ = true;
}

void LazyGraphExecutor::DeviceLocker::Unlock(std::exception_ptr exptr) {
  std::lock_guard<std::mutex> lock(mutex_);
  locked_ = false;
  exptr_ = std::move(exptr);
  cv_.notify_all();
}

void LazyGraphExecutor::DeviceLocker::Barrier() {
  std::unique_lock<std::mutex> lock(mutex_);
  cv_.wait(lock, [this] { return !locked_; });
  cv_.notify_all();
  CheckResetException();
}

void LazyGraphExecutor::DeviceLocker::CheckResetException() {
  std::exception_ptr exptr = std::move(exptr_);
  exptr_ = nullptr;
  if (exptr != nullptr) {
    std::rethrow_exception(exptr);
  }
}

auto LazyGraphExecutor::DeviceLockerArena::Get() -> DeviceLockerArena* {
  static DeviceLockerArena* arena = new DeviceLockerArena();
  return arena;
}

auto LazyGraphExecutor::DeviceLockerArena::GetLocker(
    const BackendDevice& device) -> std::shared_ptr<DeviceLocker> {
  std::lock_guard<std::mutex> lock(mutex_);
  auto it = lockers_.find(device);
  if (it == lockers_.end()) {
    it = lockers_.emplace(device, std::make_shared<DeviceLocker>(device)).first;
  }
  return it->second;
}

void LazyGraphExecutor::DeviceLockerArena::DeviceBarrier(
    const BackendDevice& device) {
  auto locker = DeviceLockerArena::Get()->GetLocker(device);
  locker->Barrier();
}

std::vector<ExceptionCleanup> LazyGraphExecutor::DeviceLockerArena::LockDevices(
    const std::set<BackendDevice>& devices) {
  std::vector<ExceptionCleanup> unlocker;
  unlocker.reserve(devices.size());
  for (auto& device : devices) {
    unlocker.emplace_back(LockDevice(device));
  }
  return unlocker;
}

ExceptionCleanup LazyGraphExecutor::DeviceLockerArena::LockDevice(
    const BackendDevice& device) {
  VLOG(4) << "Waiting on device barrier for device " << device << " ...";
  std::shared_ptr<DeviceLocker> locker;
  {
    TORCH_LAZY_TIMED("DeviceLockWait");
    locker = DeviceLockerArena::Get()->GetLocker(device);
    locker->Lock();
  }
  VLOG(4) << "Waiting on device barrier for device " << device << " done!";
  return torch::lazy::ExceptionCleanup(
      [locker = std::move(locker)](
          torch::lazy::ExceptionCleanup::StatusType status) {
        locker->Unlock(std::move(status));
      });
}

auto LazyGraphExecutor::DataCacheArena::Get() -> DataCacheArena* {
  static DataCacheArena* arena =
      new DataCacheArena(FLAGS_torch_lazy_device_data_cache_size);
  return arena;
}

LazyGraphExecutor::DataCacheArena::DataCacheArena(size_t max_cache_size)
    : max_cache_size_(max_cache_size) {}

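// The per-device data cache maps host tensors (hashed by content and dtype,
// see TensorHasher/TensorComparer below) to already-uploaded device handles,
// so repeated uses of the same constant avoid a fresh host-to-device copy.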
BackendDataPtr LazyGraphExecutor::DataCacheArena::GetDeviceData(
    const at::Tensor& tensor,
    const BackendDevice& device) {
  DataCacheArena::DataCache* cache = Get()->GetDataCache(device);
  BackendDataPtr device_data = cache->Get(tensor);
  if (device_data == nullptr) {
    at::Tensor tensor_copy = CopyTensor(tensor);
    device_data = TensorToDataHandle(tensor_copy, device);
    cache->Add(std::move(tensor_copy), device_data);
    TORCH_LAZY_COUNTER("DeviceDataCacheMiss", 1);
  }
  return device_data;
}

BackendDataPtr LazyGraphExecutor::DataCacheArena::GetDeviceData(
    const at::Scalar& value,
    at::ScalarType scalar_type,
    const BackendDevice& device) {
  // Workaround since at::scalar_tensor doesn't support bfloat16 yet.
  at::Tensor t = at::scalar_tensor(
      value,
      at::TensorOptions(
          scalar_type == at::ScalarType::BFloat16 ? at::ScalarType::Float
                                                  : scalar_type));
  if (scalar_type == at::ScalarType::BFloat16) {
    t = t.to(scalar_type);
  }
  return GetDeviceData(t, device);
}

size_t LazyGraphExecutor::DataCacheArena::TensorHasher::operator()(
    const at::Tensor& tensor) const {
  return HashReduce(
      HashCombine(GetEnumValue(tensor.scalar_type()), TensorHash(tensor)));
}

bool LazyGraphExecutor::DataCacheArena::TensorComparer::operator()(
    const at::Tensor& tensor1,
    const at::Tensor& tensor2) const {
  return TensorCompare(tensor1, tensor2);
}

auto LazyGraphExecutor::DataCacheArena::GetDataCache(
    const BackendDevice& device) -> DataCache* {
  std::lock_guard<std::mutex> lock(mutex_);
  auto it = device_caches_.find(device);
  if (it == device_caches_.end()) {
    std::unique_ptr<DataCache> cache(new DataCache(max_cache_size_));
    it = device_caches_.emplace(device, std::move(cache)).first;
  }
  return it->second.get();
}

void LazyGraphExecutor::Register(LazyGraphExecutor* executor) {
  lazy_graph_executor_registry.store(executor);
}

LazyGraphExecutor* LazyGraphExecutor::Get() {
  auto* executor = lazy_graph_executor_registry.load();
  TORCH_CHECK(executor, "Lazy graph executor not registered.");
  return executor;
}

void LazyGraphExecutor::RegisterTensor(std::shared_ptr<LazyTensor::Data> data) {
  DeviceContextArena::Get()->RegisterTensor(data);
  TORCH_LAZY_COUNTER("CreateLtcTensor", 1);
}

void LazyGraphExecutor::UnregisterTensor(LazyTensor::Data* data) {
  DeviceContextArena::Get()->UnregisterTensor(data);
  TORCH_LAZY_COUNTER("DestroyLtcTensor", 1);
}

Value LazyGraphExecutor::GetRngSeed(const BackendDevice& device) {
  return DeviceContextArena::Get()->GetRngSeed(device);
}

uint64_t LazyGraphExecutor::GetRunningSeed(const BackendDevice& device) {
  return DeviceContextArena::Get()->GetRunningSeed(device);
}

void LazyGraphExecutor::SetRngSeed(const BackendDevice& device, uint64_t seed) {
  DeviceContextArena::Get()->SetRngSeed(device, seed);
}

void LazyGraphExecutor::DeviceBarrier(const BackendDevice& device) {
  DeviceLockerArena::Get()->DeviceBarrier(device);
}

BackendDataPtr LazyGraphExecutor::GetDeviceData(
    const at::Tensor& tensor,
    const BackendDevice& device) {
  return DataCacheArena::Get()->GetDeviceData(tensor, device);
}

BackendDataPtr LazyGraphExecutor::GetDeviceData(
    const at::Scalar& value,
    at::ScalarType scalar_type,
    const BackendDevice& device) {
  return DataCacheArena::Get()->GetDeviceData(value, scalar_type, device);
}

std::vector<LazyTensorPtr> LazyGraphExecutor::GetLiveTensors(
    const BackendDevice* device) {
  return DeviceContextArena::Get()->GetLiveTensors(device);
}

void LazyGraphExecutor::SyncLiveTensorsGraph(
    const BackendDevice* device,
    c10::ArrayRef<std::string> devices,
    bool wait) {
  auto tensors = GetLiveTensors(device);
  VLOG(4) << tensors.size() << " live tensors: devices=("
          << c10::Join(", ", devices) << ")";
  SyncTensorsGraph(&tensors, devices, wait, /*sync_ltc_data=*/true);
}

void LazyGraphExecutor::SyncTensorsGraph(
    std::vector<LazyTensorPtr>* tensors,
    c10::ArrayRef<std::string> devices,
    bool wait,
    bool sync_ltc_data) {
  VLOG(4) << "Trying to sync the value of " << tensors->size() << " tensor(s)";
  SyncTensorsConfig config;
  config.sync_ltc_data = sync_ltc_data;

  auto async = SyncTensorsGraphInternal(tensors, devices, config);
  if (FLAGS_torch_lazy_use_thread_pool && wait && async != nullptr) {
    async->mwait.Wait();
  }
}

void LazyGraphExecutor::MarkStep(const BackendDevice& device) {
  TORCH_LAZY_COUNTER("MarkStep", 1);
  DeviceContextArena::Get()->MarkStep(device);
  ScopePusher::ResetScopes();
  ResetTrimCounter();
  // Move TrieCache's current pointer back to its root.
  TrieCache::Get()->ResetCurrent();
}

void LazyGraphExecutor::WaitDeviceOps(c10::ArrayRef<BackendDevice> devices) {
  std::set<BackendDevice> wait_devices;
  if (!devices.empty()) {
    for (auto& device : devices) {
      wait_devices.insert(device);
    }
  } else {
    for (auto& device_str : DeviceContextArena::Get()->GetActiveDevices()) {
      // TODO: Remove the last use of Device(const std::string& device_spec).
      wait_devices.insert(BackendDevice(device_str));
    }
  }
  // The LockDevices() API returns a vector of ExceptionCleanup objects, which
  // is freed immediately, turning this operation into a lock barrier.
  // NOLINTNEXTLINE
  DeviceLockerArena::Get()->LockDevices(wait_devices);
}

std::vector<at::Tensor> LazyGraphExecutor::GetTensors(
    std::vector<LazyTensorPtr>* tensors) {
  VLOG(4) << "Trying to get the value of " << tensors->size() << " tensor(s)";
  return GetTensorsFused(tensors);
}

void LazyGraphExecutor::ResetTrimCounter() const {
  g_tls_data.Reset();
}

size_t LazyGraphExecutor::IncTrimCounter() const {
  return ++g_tls_data.trim_counter;
}

std::string LazyGraphExecutor::DumpBackendComputation(
    const std::vector<LazyTensorPtr>& tensors) {
  std::vector<Value> ir_values;
  for (auto& tensor : tensors) {
    Value ir_value = tensor->CurrentIrValue();
    if (ir_value) {
      ir_values.push_back(std::move(ir_value));
    }
  }
  return !ir_values.empty() ? DumpUtil::ToBackend(ir_values, BackendDevice())
                            : std::string();
}

Value LazyGraphExecutor::GetDeviceDataIrValue(
    const at::Scalar& value,
    c10::ScalarType type,
    const BackendDevice& device) {
  BackendDataPtr data = GetDeviceData(value, type, device);
  data->SetInfo(std::make_shared<DeviceDataInfo>(
      /*tensor_id=*/-1, /*read_only=*/true));
  return MakeDeviceData(std::move(data));
}

Value LazyGraphExecutor::GetIrValueForScalarFromCodegen(
    const at::Scalar& value,
    const BackendDevice& device) {
  if (IsSpecialScalar(value)) {
    return MakeScalar(value, value.type());
  }
  auto data = GetDeviceData(value, value.type(), device);
  data->SetInfo(
      std::make_shared<DeviceDataInfo>(/*tensor_id=*/-1, /*read_only=*/true));
  return MakeDeviceData(std::move(data));
}

Value LazyGraphExecutor::GetIrValueForScalar(
    const at::Scalar& value,
    c10::ScalarType type,
    const BackendDevice& device) {
  if (IsSpecialScalar(value)) {
    return MakeScalar(value, type);
  }
  return GetDeviceDataIrValue(value, type, device);
}

Value LazyGraphExecutor::GetIrValueForScalar(
    const at::Scalar& value,
    const BackendDevice& device) {
  return GetIrValueForScalar(value, value.type(), device);
}

Value LazyGraphExecutor::GetIrValueForExpandedScalar(
    const at::Scalar& value,
    const Shape& shape,
    const BackendDevice& device) {
  c10::ArrayRef<int64_t> dimensions = shape.sizes();
  auto type = shape.scalar_type();
  Value ir_value = GetIrValueForScalar(value, type, device);
  if (!dimensions.empty()) {
    ir_value = MakeExpand(
        ir_value,
        dimensions.vec(),
        /*is_scalar_expand=*/true);
  }
  return ir_value;
}

LazyGraphExecutor::Async::Async(
    SyncTensorCollection* coll,
    std::vector<BackendDataPtr> parameters_data,
    std::vector<BackendDataPtr> tensors_data,
    ComputationCache::TypePtr cached_computation)
    : mwait(1),
      indices(std::move(coll->indices)),
      unlocker(std::move(coll->unlocker)),
      parameters_data(std::move(parameters_data)),
      device(coll->device),
      cached_computation(std::move(cached_computation)),
      tensors_data(std::move(tensors_data)) {}

void LazyGraphExecutor::Async::Wait() {
  mwait.Wait();
  // Accessing other Async members is safe only after MultiWait::Wait()
  // completes.
  ExceptionCleanup::StatusType status;
  for (auto& cleanup : unlocker) {
    const ExceptionCleanup::StatusType& cleanup_status = cleanup.GetStatus();
    if (cleanup_status != nullptr) {
      if (status == nullptr) {
        status = cleanup_status;
      }
      // If we observe the status here, no need to let it propagate to the next
      // device lock operation.
      cleanup.SetStatus(nullptr);
    }
  }
  if (status != nullptr) {
    std::rethrow_exception(status);
  }
}

bool LazyGraphExecutor::ShouldSyncTensor(const LazyTensorPtr tensor) const {
  return tensor->GetIrValue()->op() != ltc_not_supported;
}

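// CollectSyncTensors scans the input tensors and builds the
// SyncTensorCollection used by the rest of the sync pipeline: the unique
// device, the indices of tensors whose IR actually needs to be synced (each
// folded into the graph hash), and, when force_ltc_data is set, the upload of
// tensors that only carry at::Tensor data and have no IR value yet.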
LazyGraphExecutor::SyncTensorCollection LazyGraphExecutor::CollectSyncTensors(
    const std::vector<LazyTensorPtr>& tensors,
    const SyncTensorsConfig& config) {
  Unique<BackendDevice> unique_device;
  for (const auto& tensor : tensors) {
    unique_device.set(tensor->GetDevice());
  }
  SyncTensorCollection coll;
  if (!unique_device) {
    return coll;
  }
  if (!config.force_ltc_data && !TensorsHaveIR(tensors)) {
    return coll;
  }

  std::vector<at::Tensor> at_tensors;
  std::vector<BackendDevice> devices;
  std::vector<size_t> at_tensor_index;
  std::unordered_set<int64_t> tensor_ids;
  // The force_ltc_data flag controls aliasing compilation, so effectively the
  // same graph with force_ltc_data on/off should not match, hash-wise.
  coll.hash = MHash(config.force_ltc_data);
  coll.config = config;
  coll.device = *unique_device;
  coll.indices.reserve(tensors.size());

  for (const auto i : c10::irange(tensors.size())) {
    if (tensor_ids.insert(tensors[i]->GetUniqueId()).second &&
        tensors[i]->CurrentDataHandle() == nullptr) {
      Value ir_value = tensors[i]->CurrentIrValue();
      if (ir_value) {
        if (ShouldSyncTensor(tensors[i])) {
          // Add only tensors which need to be synced.
          coll.hash = HashCombine(coll.hash, ir_value.hash());
          coll.indices.push_back(i);
        }
      } else if (config.force_ltc_data) {
        // The tensor only has at::Tensor data. We need to queue it for a
        // device upload.
        c10::optional<at::Tensor> tensor_data = tensors[i]->CurrentTensorData();
        TORCH_CHECK(tensor_data);
        at_tensors.push_back(*tensor_data);
        devices.push_back(tensors[i]->GetDevice());
        at_tensor_index.push_back(i);
      }
    }
  }
  if (!at_tensors.empty()) {
    TORCH_LAZY_COUNTER("SyncTensorsToData", at_tensors.size());
    std::vector<BackendDataPtr> handles =
        CreateTensorsData(at_tensors, devices);
    for (const auto i : c10::irange(handles.size())) {
      // If we are here, it means that the IR Value for the tensor is not
      // present. Also, we uploaded the at::Tensor data to the device, but such
      // data is still valid, so we leave it live on the lazy tensor (so that a
      // following ToTensor() does not need to fetch it from the device).
      tensors[at_tensor_index[i]]->data()->handle = std::move(handles[i]);
    }
  }
  VLOG(4) << "Tensors graph hash " << HashToString(coll.hash) << " on device "
          << coll.device;
  return coll;
}

std::vector<Value> LazyGraphExecutor::CollectRoots(
    const std::vector<LazyTensorPtr>& tensors,
    c10::ArrayRef<size_t> indices) {
  std::vector<Value> roots;
  roots.reserve(indices.size());
  for (auto index : indices) {
    roots.push_back(tensors.at(index)->CurrentIrValue());
  }
  return roots;
}

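// For every tensor selected for sync, record its current IR value and create
// a placeholder device-data handle of the matching shape. When sync_ltc_data
// is set and the tensor has no device data yet, its IR value is detached here
// so that placeholder installation in SetTensorData() can overlap with the
// previous execution.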
void LazyGraphExecutor::ExtractIRAndPrepareTensorData(
    std::vector<LazyTensorPtr>* tensors,
    const SyncTensorsConfig& config,
    c10::ArrayRef<size_t> indices,
    std::vector<Value>& ir_values,
    std::vector<BackendDataPtr>& tensor_data_vec) {
  ir_values.reserve(indices.size());
  tensor_data_vec.reserve(indices.size());
  for (auto index : indices) {
    LazyTensorPtr& tensor = (*tensors)[index];
    Value ir_value = tensor->CurrentIrValue();
    ir_values.push_back(ir_value);
    const BackendDevice& tensor_device = tensor->GetDevice();
    BackendDataPtr handle = getBackend()->CreateDataPlaceholder(
        tensor_device, std::move(tensor->shape()));
    tensor_data_vec.push_back(handle);
    if (tensor->CurrentDataHandle() == nullptr && config.sync_ltc_data) {
      tensor->AssignIrValue(Value());
    }
  }
}

std::vector<torch::lazy::BackendDataPtr> LazyGraphExecutor::SetTensorData(
    std::vector<LazyTensorPtr>* tensors,
    const SyncTensorsConfig& config,
    c10::ArrayRef<size_t> indices,
    const std::vector<BackendDataPtr>& tensor_data_vec) {
  std::vector<BackendDataPtr> tensors_data;
  tensors_data.reserve(indices.size());
  for (const auto i : c10::irange(indices.size())) {
    auto index = indices[i];
    LazyTensorPtr& tensor = (*tensors)[index];
    // If the config.force_ltc_data flag is true, the purpose of this tensor
    // sync operation is to truncate the IR graph and materialize device data
    // in place of the IR graph, on selected tensors. But since the operation
    // will complete asynchronously, if a tensor does not already have device
    // data, we need to install a placeholder. Since at this point we hold a
    // lock on the device where the tensors reside (locks held within the coll
    // structure, and moved into the async variable), any other operation
    // trying to access the tensor's device data will have to wait until the
    // asynchronous operation completes.
    BackendDataPtr handle = tensor->CurrentDataHandle();
    if (handle == nullptr && config.force_ltc_data) {
      handle = tensor_data_vec[i];
      // Note: We are not using the SetHandleData method here, since that
      // method resets the ir_value. We have already done the resetting as part
      // of ExtractIRAndPrepareTensorData, to overlap with the previous
      // execution.
      tensor->data()->handle = handle;
      tensor->data()->tensor_data = c10::nullopt;
    }
    tensors_data.emplace_back(std::move(handle));
  }
  return tensors_data;
}

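// RunPostOrder walks the IR graph rooted at the collected values and gathers
// the device-data nodes it finds as computation parameters, deduplicating
// them by backend handle so the same buffer is passed to the computation only
// once (parameter_sequence records the parameter slot of each occurrence).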
LazyGraphExecutor::PostOrderData LazyGraphExecutor::RunPostOrder(
    const std::vector<Value>& ir_values,
    SyncTensorCollection* coll) {
  std::vector<const Node*> roots;
  roots.reserve(ir_values.size());
  for (const auto& ir_value : ir_values) {
    roots.push_back(ir_value.node.get());
  }
  PostOrderData po_data;
  po_data.post_order = Util::ComputePostOrder(roots, &po_data.emission_map);
  std::unordered_map<BackendData::Handle, size_t> data_handles;
  for (auto node : po_data.post_order) {
    const auto backend_data = getBackend()->GetComputationDataFromNode(node);
    if (backend_data) {
      /* Acceptable race condition: HasValue may return false. This is OK
       * since the conditional barrier is a performance optimization. */
      if (!backend_data->HasValue()) {
        TensorCollectionBarrier(coll);
      }
      BackendData::Handle handle = backend_data->GetHandle();
      auto it = data_handles.find(handle);
      if (it != data_handles.end()) {
        po_data.parameter_sequence.push_back(it->second);
      } else {
        po_data.parameter_sequence.push_back(po_data.parameters_data.size());
        data_handles[handle] = po_data.parameters_data.size();
        po_data.parameters_data.push_back(backend_data);
      }
    }
  }
  return po_data;
}

std::shared_ptr<LazyGraphExecutor::Async> LazyGraphExecutor::TryRunCachedSync(
    std::vector<LazyTensorPtr>* tensors,
    SyncTensorCollection* coll,
    PostOrderData* po_data,
    const std::vector<BackendDataPtr>& tensor_data_vec) {
  ComputationCache::TypePtr cached_computation =
      LookupCachedCompile(coll->hash);
  if (cached_computation == nullptr) {
    return nullptr;
  }
  if (GRAPH_DUMP_ENABLED) {
    auto* comp = cached_computation->computation.get();
    LOG(ERROR) << "Run a cached graph: " << comp->to_string() << std::endl;
  }
  TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", po_data->post_order.size());
  VLOG(5) << "TensorsGraphSize=" << po_data->post_order.size();

  return ScheduleSyncTensorsGraph(
      tensors,
      coll,
      std::move(po_data->parameters_data),
      std::move(cached_computation),
      tensor_data_vec);
}

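// Compile lowers the post-order nodes into a backend computation via
// LoweringContext, adds the collected IR values as results, and hands the
// lowered graph to the backend compiler. The returned CompilationResult
// carries the emitted node count and the parameter data that later feeds
// ExecuteComputation.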
LazyGraphExecutor::CompilationResult LazyGraphExecutor::Compile(
    const std::vector<LazyTensorPtr>& tensors,
    c10::ArrayRef<std::string> devices,
    const SyncTensorCollection& coll,
    PostOrderData* po_data,
    const std::vector<Value>& ir_values) {
  auto lowering_ctx = LoweringContext::Create(
      "SyncTensorsGraph",
      coll.device,
      po_data->post_order,
      std::move(po_data->emission_map));
  for (const auto& ir_value : ir_values) {
    lowering_ctx->AddResult(ir_value);
  }

  ComputationPtr computation = lowering_ctx->Build();
  // If force_ltc_data is true, it means that we did a proper sync and are
  // inside a mark step. If GetTensors was called, force_ltc_data will be
  // false, meaning we are prematurely evaluating some value.
  computation->in_mark_step = coll.config.force_ltc_data;

  VLOG(3) << "Compiling IR graph hash " << HashToString(coll.hash)
          << " on device " << coll.device << " ...";
  std::vector<ComputationPtr> computations =
      getBackend()->Compile({computation});
  VLOG(3) << "Compiling IR graph hash " << HashToString(coll.hash)
          << " on device " << coll.device << " done!";
  if (computation) {
    // TODO(whc) should computation be allowed null here? (because it is in one
    // case)
    TORCH_CHECK(
        computation->parameters_size() == po_data->parameters_data.size());
  }

  return {
      /*device=*/coll.device,
      /*emitted_nodes=*/lowering_ctx->GetEmittedNodeCount(),
      /*computation=*/std::move(computations.front()),
      /*parameters_data=*/std::move(po_data->parameters_data)};
}

LazyGraphExecutor::ComputationCache* LazyGraphExecutor::GetComputationCache() {
  static ComputationCache* cache =
      new ComputationCache(FLAGS_torch_lazy_compilation_cache_size);
  return cache;
}

LazyGraphExecutor::ComputationCache::TypePtr LazyGraphExecutor::
    LookupCachedCompile(const hash_t& hash) {
  ComputationCache::TypePtr cached_computation =
      GetComputationCache()->Get(hash);
  if (cached_computation == nullptr) {
    TORCH_LAZY_COUNTER("UncachedCompile", 1);
    return nullptr;
  }
  TORCH_LAZY_COUNTER("CachedCompile", 1);
  return cached_computation;
}

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

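// End-to-end sync pipeline: collect the tensors to sync, extract their IR and
// prepare placeholder handles, run the post-order to gather parameters, fold
// the parameter sequence into the graph hash, then either reuse a cached
// compilation or compile the lowered graph, cache it, and schedule the
// (possibly asynchronous) execution.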
std::shared_ptr<LazyGraphExecutor::Async> LazyGraphExecutor::
    SyncTensorsGraphInternal(
        std::vector<LazyTensorPtr>* tensors,
        c10::ArrayRef<std::string> devices,
        const SyncTensorsConfig& config) {
  SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
  if (coll.indices.empty()) {
    /* Ensure previous execution is complete before exiting this
     * function */
    TensorCollectionBarrier(&coll);
    return nullptr;
  }
  DebugUtil::SaveTensorsGraphInfo(
      "ScheduleSyncTensorsGraph", *tensors, &coll.indices);
  std::vector<Value> ir_values;
  std::vector<BackendDataPtr> tensor_data_vec;
  ExtractIRAndPrepareTensorData(
      tensors, coll.config, coll.indices, ir_values, tensor_data_vec);
  PostOrderData po_data = RunPostOrder(ir_values, &coll);
  coll.hash = HashCombine(coll.hash, Hash(po_data.parameter_sequence));
  VLOG(4) << "Parameter sequence graph hash " << HashToString(coll.hash);
  std::shared_ptr<Async> async =
      TryRunCachedSync(tensors, &coll, &po_data, tensor_data_vec);
  if (async != nullptr) {
    return async;
  }

  CompilationResult compile_result =
      Compile(*tensors, devices, coll, &po_data, ir_values);
  if (GRAPH_DUMP_ENABLED) {
    auto* comp = compile_result.computation.get();
    LOG(ERROR) << "Add a cached computation with hash " << coll.hash
               << std::endl;
    LOG(ERROR) << "Add a graph to cache: " << comp->to_string() << std::endl;
  }

  TORCH_LAZY_VALUE_METRIC("TensorsGraphSize", compile_result.emitted_nodes);
  VLOG(5) << "TensorsGraphSize=" << compile_result.emitted_nodes;

  auto cached_computation = std::make_shared<CachedComputation>(
      std::move(compile_result.computation));
  GetComputationCache()->Add(coll.hash, cached_computation);

  return ScheduleSyncTensorsGraph(
      tensors,
      &coll,
      std::move(compile_result.parameters_data),
      std::move(cached_computation),
      tensor_data_vec);
}

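// Wraps the execution in an Async object: the device locks taken for `coll`
// move into it, so the device stays locked until the closure below completes.
// The closure executes the cached computation, assigns each result to the
// corresponding output placeholder, and on failure records the exception in
// the unlockers (and rethrows) so it surfaces either via Async::Wait() or the
// next device lock acquisition.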
std::shared_ptr<LazyGraphExecutor::Async> LazyGraphExecutor::
    ScheduleSyncTensorsGraph(
        SyncTensorCollection* coll,
        std::vector<BackendDataPtr> parameters_data,
        std::vector<BackendDataPtr> tensors_data,
        ComputationCache::TypePtr cached_computation) {
  TensorCollectionBarrier(coll);
  std::shared_ptr<Async> async = std::make_shared<Async>(
      coll,
      std::move(parameters_data),
      std::move(tensors_data),
      std::move(cached_computation));

  auto syncfn = [async, hash = coll->hash]() {
    try {
      VLOG(3) << "Executing IR graph hash " << HashToString(hash)
              << " on device " << async->device << " ...";
      auto results = getBackend()->ExecuteComputation(
          async->cached_computation->computation,
          async->parameters_data,
          async->device);
      VLOG(3) << "Executing IR graph hash " << HashToString(hash)
              << " on device " << async->device << " done!";

      TORCH_CHECK(
          async->tensors_data.size() == results.size(),
          "Expected number of outputs does not match TorchScript Stack size: ",
          async->tensors_data.size(),
          " != ",
          results.size());

      for (const auto i : c10::irange(results.size())) {
        if (async->tensors_data[i] != nullptr) {
          async->tensors_data[i]->Assign(*results[i]);
        } else {
          async->tensors_data[i] = std::move(results[i]);
        }
      }
    } catch (...) {
      // There are two paths of discovery of an exception happening on an
      // asynchronous task. One happens if the creator of the asynchronous task
      // explicitly waits for completion, in which case the exception will be
      // thrown from the Wait() API. Re-throwing the exception below makes sure
      // this will be captured by the completer function created below, and
      // surfaced by the Wait() API. But we also need to surface the exception
      // even in case the caller does not wait, and that is accomplished by
      // setting the unlockers status. In that case the exception will be
      // surfaced when the user tries to acquire the device locks the next
      // time.
      for (auto& unlocker : async->unlocker) {
        unlocker.SetStatus(std::current_exception());
      }
      throw;
    }
  };

  if (FLAGS_torch_lazy_use_thread_pool) {
    ScheduleIoClosure(async->mwait.Completer(std::move(syncfn)));
  } else {
    syncfn();
  }
  return async;
}

std::shared_ptr<LazyGraphExecutor::Async> LazyGraphExecutor::
    ScheduleSyncTensorsGraph(
        std::vector<LazyTensorPtr>* tensors,
        SyncTensorCollection* coll,
        std::vector<BackendDataPtr> parameters_data,
        ComputationCache::TypePtr cached_computation,
        const std::vector<BackendDataPtr>& tensor_data_vec) {
  auto tensors_data =
      SetTensorData(tensors, coll->config, coll->indices, tensor_data_vec);
  return ScheduleSyncTensorsGraph(
      coll,
      std::move(parameters_data),
      std::move(tensors_data),
      std::move(cached_computation));
}

std::vector<at::Tensor> LazyGraphExecutor::GetTensorsFused(
    std::vector<LazyTensorPtr>* tensors) {
  SyncTensorsConfig config;
  config.force_ltc_data = false;
  auto async = SyncTensorsGraphInternal(tensors, {}, config);
  if (FLAGS_torch_lazy_use_thread_pool && async != nullptr) {
    async->mwait.Wait();
  }
  std::vector<BackendDataPtr> tensors_data = GatherTensorsData(
      *tensors,
      async != nullptr ? async->indices : c10::ArrayRef<size_t>(),
      async != nullptr ? async->tensors_data : c10::ArrayRef<BackendDataPtr>());
  return FetchTensors(
      tensors, tensors_data, async != nullptr ? &async->indices : nullptr);
}

// This gets tensors from the backend. For the TS backend, we'd ideally just
// cut through these layers and not need to copy the tensor, just move it.
// For the XLA backend, a copy is going to have to happen.
//
// Could we replace the 'Data' object with an at::Tensor which is 'undefined'
// unless a backend attaches a buffer to it? That way we could have a
// 'PopulateTensor' method on the backend, which can either attach an existing
// tensor buffer to the wrapper, or copy data.
std::vector<at::Tensor> LazyGraphExecutor::FetchTensors(
    std::vector<LazyTensorPtr>* tensors,
    c10::ArrayRef<BackendDataPtr> tensors_data,
    const std::vector<size_t>* indices) {
  std::vector<at::Tensor> results;
  size_t literals_index = 0;
  size_t sync_index = 0;
  results.reserve(tensors->size());
  for (const auto i : c10::irange(tensors->size())) {
    if (indices != nullptr && sync_index < indices->size() &&
        i == (*indices)[sync_index]) {
      results.push_back(getBackend()->MakeTensorFromComputationData(
          tensors_data[literals_index], (*tensors)[i]->dtype()));
      ++literals_index;
      ++sync_index;
    } else {
      c10::optional<at::Tensor> tensor_data =
          (*tensors)[i]->CurrentTensorData();
      if (tensor_data) {
        results.push_back(*tensor_data);
      } else {
        TORCH_CHECK(literals_index < tensors_data.size());
        results.push_back(getBackend()->MakeTensorFromComputationData(
            tensors_data[literals_index], (*tensors)[i]->dtype()));
        ++literals_index;
      }
    }
  }
  return results;
}

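// For each input tensor, pick the device data to read back from one of three
// sources: a previously processed duplicate (same unique id), the Async
// results for tensors that were part of this sync (matched via `indices`), or
// the handle the tensor already owned when it has no host-side data.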
std::vector<BackendDataPtr> LazyGraphExecutor::GatherTensorsData(
    const std::vector<LazyTensorPtr>& tensors,
    c10::ArrayRef<size_t> indices,
    c10::ArrayRef<BackendDataPtr> tensors_data) {
  std::vector<BackendDataPtr> result_tensors_data;
  std::unordered_map<int64_t, size_t> uid_index_map;
  size_t indices_index = 0;
  for (const auto i : c10::irange(tensors.size())) {
    int64_t tensor_id = tensors[i]->GetUniqueId();
    auto it = uid_index_map.find(tensor_id);
    if (it != uid_index_map.end()) {
      // The current tensor is a duplicate of a previously processed tensor
      // that had an IR Node to sync. Get the data from the tensor_data_map.
      result_tensors_data.push_back(result_tensors_data[it->second]);
    } else if (indices_index < indices.size() && i == indices[indices_index]) {
      // If we are at the current index (it means that the tensor at index
      // 'i' had an IR node to sync), use the data held within the Async
      // object.
      uid_index_map.emplace(tensor_id, result_tensors_data.size());
      result_tensors_data.push_back(tensors_data[indices_index]);
      ++indices_index;
    } else if (!tensors[i]->CurrentTensorData()) {
      BackendDataPtr handle = tensors[i]->CurrentDataHandle();
      TORCH_CHECK(handle != nullptr);
      result_tensors_data.push_back(std::move(handle));
    }
  }
  return result_tensors_data;
}

void LazyGraphExecutor::TensorCollectionBarrier(SyncTensorCollection* coll) {
  if (coll) {
    static const std::string invalid_device(
        "Unknown0"); /* Temporary solution to identify unassigned devices */
    if (coll->device.toString() == invalid_device || !coll->unlocker.empty()) {
      return;
    }
    VLOG(4) << "Waiting on device barrier for device " << coll->device
            << " ...";
    {
      TORCH_LAZY_TIMED("DeviceLockWait");
      coll->unlocker = DeviceLockerArena::Get()->LockDevices({coll->device});
    }
    VLOG(4) << "Waiting on device barrier for device " << coll->device
            << " done!";
  }
}

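// Computes the hash that the sync path would use for these tensors (the
// collection hash combined with the parameter sequence), without compiling or
// executing anything.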
hash_t LazyGraphExecutor::GetGraphHash(
    const std::vector<LazyTensorPtr>& tensors) {
  SyncTensorsConfig config;
  config.sync_ltc_data = false;

  auto coll = CollectSyncTensors(tensors, config);
  std::vector<Value> ir_values;
  for (auto index : coll.indices) {
    Value ir_value = tensors[index]->CurrentIrValue();
    ir_values.push_back(ir_value);
  }
  auto po_data = RunPostOrder(ir_values, &coll);
  coll.hash = HashCombine(coll.hash, Hash(po_data.parameter_sequence));
  return coll.hash;
}

} // namespace lazy
} // namespace torch