1 | #pragma once |
2 | |
#include <string>

#include <c10/util/Exception.h>
#include <c10/util/Registry.h>
5 | |
// Help text included in the error raised by the default (unimplemented)
// ORT hooks, telling the user how to make the 'ort' device available.
// NOTE(review): the leading space looks intended to separate this text from
// a message it is appended to — confirm before trimming it.
constexpr const char* ORT_HELP =
    " You need to 'import torch_ort' to use the 'ort' device in PyTorch."
    " The 'torch_ort' module is provided by the ONNX Runtime itself"
    " (https://onnxruntime.ai).";
10 | |
11 | // NB: Class must live in `at` due to limitations of Registry.h. |
12 | namespace at { |
13 | |
14 | struct TORCH_API ORTHooksInterface { |
15 | // This should never actually be implemented, but it is used to |
16 | // squelch -Werror=non-virtual-dtor |
17 | virtual ~ORTHooksInterface() = default; |
18 | |
19 | virtual std::string showConfig() const { |
20 | TORCH_CHECK(false, "Cannot query detailed ORT version information." , ORT_HELP); |
21 | } |
22 | }; |
23 | |
24 | // NB: dummy argument to suppress "ISO C++11 requires at least one argument |
25 | // for the "..." in a variadic macro" |
26 | struct TORCH_API ORTHooksArgs {}; |
27 | |
28 | C10_DECLARE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs); |
29 | #define REGISTER_ORT_HOOKS(clsname) \ |
30 | C10_REGISTER_CLASS(ORTHooksRegistry, clsname, clsname) |
31 | |
32 | namespace detail { |
33 | TORCH_API const ORTHooksInterface& getORTHooks(); |
34 | } // namespace detail |
35 | |
36 | } // namespace at |
37 | |