1 | #include <torch/csrc/lazy/backend/backend_interface.h> |
2 | #include <torch/csrc/lazy/core/internal_ops/ltc_ops.h> |
3 | |
4 | namespace torch { |
5 | namespace lazy { |
6 | |
7 | namespace { |
8 | std::atomic<const BackendImplInterface*> backend_impl_registry; |
9 | } // namespace |
10 | |
11 | bool hasBackend() { |
12 | return !!backend_impl_registry.load(); |
13 | } |
14 | |
15 | const BackendImplInterface* getBackend() { |
16 | auto* interface = backend_impl_registry.load(); |
17 | TORCH_CHECK(interface, "Lazy tensor backend not registered." ); |
18 | return interface; |
19 | } |
20 | |
21 | BackendRegistrar::BackendRegistrar( |
22 | const BackendImplInterface* backend_impl_interface) { |
23 | backend_impl_registry.store(backend_impl_interface); |
24 | } |
25 | |
26 | // Get IrBuilder from backend. Use TorchScriptIrBuilder by default |
27 | const IrBuilder* getIrBuilder() { |
28 | static const IrBuilder* builder = getBackend()->GetIrBuilder(); |
29 | return builder; |
30 | } |
31 | |
32 | at::Tensor MakeTensorFromComputationData( |
33 | const BackendDataPtr data, |
34 | c10::optional<at::ScalarType> logical_scalar_type) { |
35 | return getBackend()->MakeTensorFromComputationData(data, logical_scalar_type); |
36 | } |
37 | |
38 | std::unique_ptr<LoweringContext> LoweringContext::Create( |
39 | const std::string& name, |
40 | BackendDevice device, |
41 | c10::ArrayRef<const Node*> post_order, |
42 | Util::EmissionMap emit_status) { |
43 | return getBackend()->CreateLoweringContext( |
44 | name, std::move(device), post_order, emit_status); |
45 | } |
46 | |
47 | std::unique_ptr<LoweringContext> LoweringContext::Create( |
48 | const std::string& name, |
49 | BackendDevice device) { |
50 | return getBackend()->CreateLoweringContext(name, std::move(device)); |
51 | } |
52 | |
53 | } // namespace lazy |
54 | } // namespace torch |
55 | |