1#pragma once
2
3#include <c10/util/Exception.h>
4#include <c10/util/Registry.h>
5
// Help text appended to ORT-related error messages (see
// ORTHooksInterface::showConfig). It begins with a space, presumably so it
// reads correctly when concatenated after a preceding sentence.
//
// NOTE(review): as a namespace-scope const object this has internal linkage,
// so every translation unit gets its own copy; once C++17 can be assumed,
// `inline constexpr` would give a single definition shared across TUs.
constexpr const char* const ORT_HELP =
    " You need to 'import torch_ort' to use the 'ort' device"
    " in PyTorch. The 'torch_ort' module is provided by the"
    " ONNX Runtime itself (https://onnxruntime.ai).";
10
// NB: Class must live in `at` due to limitations of Registry.h.
namespace at {

// Interface through which ATen reaches an ONNX Runtime ('ort') backend.
// The defaults below only raise errors; a concrete implementation is meant
// to be registered with REGISTER_ORT_HOOKS (presumably by the torch_ort
// module named in ORT_HELP — confirm).
//
// NOTE(review): this is an exported (TORCH_API) polymorphic class, so the
// order of its virtual member functions determines its vtable layout.
// Append new virtuals instead of reordering existing ones.
struct TORCH_API ORTHooksInterface {
  // Defaulted virtual destructor. It exists to squelch
  // -Werror=non-virtual-dtor and so that an implementation can be destroyed
  // through a pointer to this interface. No implementation needs to
  // override it.
  virtual ~ORTHooksInterface() = default;

  // Intended to report detailed ORT version/configuration information, per
  // the error text below. This default always fails: TORCH_CHECK(false, ...)
  // raises a c10 error whose message is the sentence below followed by
  // ORT_HELP. There is deliberately no return statement, because the check
  // never returns.
  virtual std::string showConfig() const {
    TORCH_CHECK(false, "Cannot query detailed ORT version information.", ORT_HELP);
  }
};

// NB: dummy argument to suppress "ISO C++11 requires at least one argument
// for the "..." in a variadic macro"
// It is empty, so it carries no constructor arguments for hooks created
// through ORTHooksRegistry.
struct TORCH_API ORTHooksArgs {};

// Registry of ORTHooksInterface implementations constructed from
// ORTHooksArgs. Key type and lookup semantics come from C10_DECLARE_REGISTRY
// in Registry.h.
C10_DECLARE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs);
// Registers class `clsname` in ORTHooksRegistry, using `clsname` itself as
// the registration key (C10_REGISTER_CLASS takes the key, then the class).
#define REGISTER_ORT_HOOKS(clsname) \
  C10_REGISTER_CLASS(ORTHooksRegistry, clsname, clsname)

namespace detail {
// Accessor for the active ORT hooks. It is defined out of line, presumably
// returning the registered implementation or falling back to the
// always-erroring base; verify in the corresponding .cpp.
TORCH_API const ORTHooksInterface& getORTHooks();
} // namespace detail

} // namespace at
37