1 | #ifndef MetalContext_h |
2 | #define MetalContext_h |
3 | |
4 | #include <atomic> |
5 | |
6 | #include <ATen/Tensor.h> |
7 | |
8 | namespace at { |
9 | namespace metal { |
10 | |
11 | struct MetalInterface { |
12 | virtual ~MetalInterface() = default; |
13 | virtual bool is_metal_available() const = 0; |
14 | virtual at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src) |
15 | const = 0; |
16 | }; |
17 | |
18 | extern std::atomic<const MetalInterface*> g_metal_impl_registry; |
19 | |
20 | class MetalImplRegistrar { |
21 | public: |
22 | explicit MetalImplRegistrar(MetalInterface*); |
23 | }; |
24 | |
25 | at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src); |
26 | |
27 | } // namespace metal |
28 | |
29 | namespace native { |
30 | bool is_metal_available(); |
31 | } // namespace native |
32 | |
33 | } // namespace at |
34 | |
35 | #endif /* MetalContext_h */ |
36 | |