1 | #pragma once |
2 | |
#include <c10/core/Allocator.h>
#include <ATen/core/Generator.h>
#include <c10/util/Exception.h>

#include <c10/util/Registry.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
12 | |
// Forward declaration only: this header refers to Context solely through
// pointers, so the full class definition is not required here.
namespace at { class Context; } // namespace at
16 | |
17 | // NB: Class must live in `at` due to limitations of Registry.h. |
18 | namespace at { |
19 | |
20 | // The HIPHooksInterface is an omnibus interface for any HIP functionality |
21 | // which we may want to call into from CPU code (and thus must be dynamically |
22 | // dispatched, to allow for separate compilation of HIP code). See |
23 | // CUDAHooksInterface for more detailed motivation. |
24 | struct TORCH_API HIPHooksInterface { |
25 | // This should never actually be implemented, but it is used to |
26 | // squelch -Werror=non-virtual-dtor |
27 | virtual ~HIPHooksInterface() = default; |
28 | |
29 | // Initialize the HIP library state |
30 | virtual void initHIP() const { |
31 | AT_ERROR("Cannot initialize HIP without ATen_hip library." ); |
32 | } |
33 | |
34 | virtual std::unique_ptr<c10::GeneratorImpl> initHIPGenerator(Context*) const { |
35 | AT_ERROR("Cannot initialize HIP generator without ATen_hip library." ); |
36 | } |
37 | |
38 | virtual bool hasHIP() const { |
39 | return false; |
40 | } |
41 | |
42 | virtual int64_t current_device() const { |
43 | return -1; |
44 | } |
45 | |
46 | virtual Allocator* getPinnedMemoryAllocator() const { |
47 | AT_ERROR("Pinned memory requires HIP." ); |
48 | } |
49 | |
50 | virtual void registerHIPTypes(Context*) const { |
51 | AT_ERROR("Cannot registerHIPTypes() without ATen_hip library." ); |
52 | } |
53 | |
54 | virtual int getNumGPUs() const { |
55 | return 0; |
56 | } |
57 | }; |
58 | |
// NB: dummy argument to suppress "ISO C++11 requires at least one argument
// for the "..." in a variadic macro"
// HIPHooksArgs carries no data; it exists only so the registry macros below
// have a non-empty argument type to pass through.
struct TORCH_API HIPHooksArgs {};

// Declares HIPHooksRegistry, a registry of HIPHooksInterface implementations
// parameterized by HIPHooksArgs (see c10/util/Registry.h for the mechanics).
C10_DECLARE_REGISTRY(HIPHooksRegistry, HIPHooksInterface, HIPHooksArgs);

// Registers the hooks class `clsname` in HIPHooksRegistry, passing `clsname`
// both as the registration key and as the class to register.
// NOTE(review): presumably used by the HIP-enabled library so that
// detail::getHIPHooks() can locate the real implementation; confirm in the .cpp.
#define REGISTER_HIP_HOOKS(clsname) \
  C10_REGISTER_CLASS(HIPHooksRegistry, clsname, clsname)
66 | |
67 | namespace detail { |
68 | TORCH_API const HIPHooksInterface& getHIPHooks(); |
69 | |
70 | } // namespace detail |
71 | } // namespace at |
72 | |