1#pragma once
2
#include <c10/core/Allocator.h>
#include <ATen/core/Generator.h>
#include <c10/util/Exception.h>

#include <c10/util/Registry.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
12
namespace at {
// Forward declaration only: this header uses Context exclusively through
// pointers, so the full definition is not needed here.
class Context;
} // namespace at
16
17// NB: Class must live in `at` due to limitations of Registry.h.
18namespace at {
19
20// The HIPHooksInterface is an omnibus interface for any HIP functionality
21// which we may want to call into from CPU code (and thus must be dynamically
22// dispatched, to allow for separate compilation of HIP code). See
23// CUDAHooksInterface for more detailed motivation.
24struct TORCH_API HIPHooksInterface {
25 // This should never actually be implemented, but it is used to
26 // squelch -Werror=non-virtual-dtor
27 virtual ~HIPHooksInterface() = default;
28
29 // Initialize the HIP library state
30 virtual void initHIP() const {
31 AT_ERROR("Cannot initialize HIP without ATen_hip library.");
32 }
33
34 virtual std::unique_ptr<c10::GeneratorImpl> initHIPGenerator(Context*) const {
35 AT_ERROR("Cannot initialize HIP generator without ATen_hip library.");
36 }
37
38 virtual bool hasHIP() const {
39 return false;
40 }
41
42 virtual int64_t current_device() const {
43 return -1;
44 }
45
46 virtual Allocator* getPinnedMemoryAllocator() const {
47 AT_ERROR("Pinned memory requires HIP.");
48 }
49
50 virtual void registerHIPTypes(Context*) const {
51 AT_ERROR("Cannot registerHIPTypes() without ATen_hip library.");
52 }
53
54 virtual int getNumGPUs() const {
55 return 0;
56 }
57};
58
59// NB: dummy argument to suppress "ISO C++11 requires at least one argument
60// for the "..." in a variadic macro"
61struct TORCH_API HIPHooksArgs {};
62
63C10_DECLARE_REGISTRY(HIPHooksRegistry, HIPHooksInterface, HIPHooksArgs);
64#define REGISTER_HIP_HOOKS(clsname) \
65 C10_REGISTER_CLASS(HIPHooksRegistry, clsname, clsname)
66
67namespace detail {
68TORCH_API const HIPHooksInterface& getHIPHooks();
69
70} // namespace detail
71} // namespace at
72