1 | // Copyright © 2022 Apple Inc. |
2 | |
3 | #pragma once |
4 | #include <c10/macros/Macros.h> |
5 | #include <c10/util/Exception.h> |
6 | #include <ATen/ATen.h> |
7 | |
8 | |
// Metal handle aliases.
// Objective-C(++) translation units get the real protocol/class types from the
// Metal frameworks; plain C++ translation units get opaque `void*` stand-ins
// under the same alias names so this header can be shared by both.
#ifndef __OBJC__
using MTLDevice = void*;
using MTLDevice_t = void*;
using MTLLibrary_t = void*;
using MTLFunction_t = void*;
using MTLFunctionConstantValues_t = void*;
#else
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
using MTLDevice_t = id<MTLDevice>;
using MTLLibrary_t = id<MTLLibrary>;
using MTLFunction_t = id<MTLFunction>;
using MTLFunctionConstantValues_t = MTLFunctionConstantValues*;
#endif

// NOTE(review): a file-scope using-directive in a header pulls every `std`
// name into each translation unit that includes this file. It is kept only
// because includers may already rely on it; new code should write `std::`.
using namespace std;
26 | |
27 | namespace at { |
28 | namespace mps { |
29 | |
30 | // Helper enum to check if a MPSGraph op is supported in a given macOS version |
31 | enum class MacOSVersion : uint32_t { |
32 | MACOS_VER_13_0_PLUS = 0, |
33 | MACOS_VER_13_1_PLUS, |
34 | MACOS_VER_13_2_PLUS, |
35 | MACOS_VER_13_3_PLUS, |
36 | }; |
37 | |
38 | //----------------------------------------------------------------- |
39 | // MPSDevice |
40 | // |
41 | // MPSDevice is a singleton class that returns the default device |
42 | //----------------------------------------------------------------- |
43 | |
44 | class TORCH_API MPSDevice { |
45 | public: |
46 | /** |
47 | * MPSDevice should not be cloneable. |
48 | */ |
49 | MPSDevice(MPSDevice& other) = delete; |
50 | /** |
51 | * MPSDevice should not be assignable. |
52 | */ |
53 | void operator=(const MPSDevice&) = delete; |
54 | /** |
55 | * Gets single instance of the Device. |
56 | */ |
57 | static MPSDevice* getInstance(); |
58 | /** |
59 | * Returns the single device. |
60 | */ |
61 | MTLDevice_t device() { |
62 | return _mtl_device; |
63 | } |
64 | /** |
65 | * Returns whether running on Ventura or newer |
66 | */ |
67 | bool isMacOS13Plus(MacOSVersion version) const; |
68 | |
69 | MTLFunction_t metalIndexingFunction(const std::string &kernel, MTLFunctionConstantValues_t constantValues); |
70 | |
71 | ~MPSDevice(); |
72 | |
73 | private: |
74 | static MPSDevice* _device; |
75 | MTLDevice_t _mtl_device; |
76 | MTLLibrary_t _mtl_indexing_library; |
77 | MPSDevice(); |
78 | }; |
79 | |
80 | TORCH_API bool is_available(); |
81 | TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS); |
82 | |
83 | TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false); |
84 | |
85 | } // namespace mps |
86 | } // namespace at |
87 | |