1 | // Copyright © 2022 Apple Inc. |
---|---|
2 | |
3 | #pragma once |
4 | |
5 | #include <c10/core/Allocator.h> |
6 | #include <ATen/core/Generator.h> |
7 | #include <c10/util/Exception.h> |
8 | #include <c10/util/Registry.h> |
9 | |
10 | #include <cstddef> |
11 | #include <functional> |
12 | |
// Forward-declares at::Context so the name can be used without pulling in
// its full definition from this header.
namespace at {
class Context;
} // namespace at
16 | |
17 | namespace at { |
18 | |
19 | struct TORCH_API MPSHooksInterface { |
20 | virtual ~MPSHooksInterface() = default; |
21 | |
22 | // Initialize the MPS library state |
23 | virtual void initMPS() const { |
24 | AT_ERROR("Cannot initialize MPS without MPS backend."); |
25 | } |
26 | |
27 | virtual bool hasMPS() const { |
28 | return false; |
29 | } |
30 | |
31 | virtual const Generator& getDefaultMPSGenerator() const { |
32 | AT_ERROR("Cannot get default MPS generator without MPS backend."); |
33 | } |
34 | |
35 | virtual Allocator* getMPSDeviceAllocator() const { |
36 | AT_ERROR("MPSDeviceAllocator requires MPS."); |
37 | } |
38 | }; |
39 | |
40 | struct TORCH_API MPSHooksArgs {}; |
41 | |
42 | C10_DECLARE_REGISTRY(MPSHooksRegistry, MPSHooksInterface, MPSHooksArgs); |
43 | #define REGISTER_MPS_HOOKS(clsname) \ |
44 | C10_REGISTER_CLASS(MPSHooksRegistry, clsname, clsname) |
45 | |
46 | namespace detail { |
47 | TORCH_API const MPSHooksInterface& getMPSHooks(); |
48 | |
49 | } // namespace detail |
50 | } // namespace at |
51 |