1// Copyright © 2022 Apple Inc.
2
3#pragma once
4
5#include <c10/core/Allocator.h>
6#include <ATen/core/Generator.h>
7#include <c10/util/Exception.h>
8#include <c10/util/Registry.h>
9
10#include <cstddef>
11#include <functional>
12
// Forward declaration only: lets at::Context be named (e.g. by pointer or
// reference) without including its full definition. NOTE(review): nothing in
// the declarations of this header refers to it as written.
namespace at {
class Context;
} // namespace at
16
17namespace at {
18
19struct TORCH_API MPSHooksInterface {
20 virtual ~MPSHooksInterface() = default;
21
22 // Initialize the MPS library state
23 virtual void initMPS() const {
24 AT_ERROR("Cannot initialize MPS without MPS backend.");
25 }
26
27 virtual bool hasMPS() const {
28 return false;
29 }
30
31 virtual const Generator& getDefaultMPSGenerator() const {
32 AT_ERROR("Cannot get default MPS generator without MPS backend.");
33 }
34
35 virtual Allocator* getMPSDeviceAllocator() const {
36 AT_ERROR("MPSDeviceAllocator requires MPS.");
37 }
38};
39
40struct TORCH_API MPSHooksArgs {};
41
42C10_DECLARE_REGISTRY(MPSHooksRegistry, MPSHooksInterface, MPSHooksArgs);
43#define REGISTER_MPS_HOOKS(clsname) \
44 C10_REGISTER_CLASS(MPSHooksRegistry, clsname, clsname)
45
46namespace detail {
47TORCH_API const MPSHooksInterface& getMPSHooks();
48
49} // namespace detail
50} // namespace at
51