1#pragma once
2
3#include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h>
4#include <ATen/core/function.h>
5#include <c10/util/Metaprogramming.h>
6#include <c10/util/TypeTraits.h>
7#include <c10/util/irange.h>
8
9namespace torch {
10
namespace detail {
/**
 * In the Facebook internal build (using BUCK), this macro is enabled by
 * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
 * binary.
 */
#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
TORCH_API void record_custom_class(std::string name);

/**
 * Record an instance of a custom class being loaded.
 *
 * Only the portion of NAME after the final '.' of the qualified name is
 * recorded, as this seemingly aligns with how users name their custom
 * classes. If NAME contains no '.', the whole string is recorded.
 * Example: __torch__.torch.classes.xnnpack.Conv2dOpContext is recorded as
 * "Conv2dOpContext".
 *
 * The expansion lives in its own braced scope and uses a distinctive local
 * name, so the macro can be used several times in one scope and does not
 * clash with a caller variable. The previous form declared `auto name` in
 * the caller's scope, which made `RECORD_CUSTOM_CLASS(name)` read its own
 * uninitialized variable. The call is fully qualified, so it resolves to
 * torch::detail wherever the macro is expanded.
 */
#define RECORD_CUSTOM_CLASS(NAME)                                      \
  {                                                                    \
    const std::string torch_custom_class_qualname_(NAME);              \
    ::torch::detail::record_custom_class(                              \
        torch_custom_class_qualname_.substr(                           \
            torch_custom_class_qualname_.find_last_of('.') + 1));      \
  }
#else
#define RECORD_CUSTOM_CLASS(NAME)
#endif
} // namespace detail
33
34/// This struct is used to represent default values for arguments
35/// when registering methods for custom classes.
36/// static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
37/// .def("myMethod", &Foo::myMethod, {torch::arg("name") = name});
38struct arg {
39 // Static method for representing a default value of None. This is meant to
40 // be used like so:
41 // torch::arg("name") = torch::arg::none
42 // and is identical to:
43 // torch::arg("name") = IValue()
44 static c10::IValue none() {
45 return c10::IValue();
46 }
47
48 // Explicit constructor.
49 explicit arg(std::string name)
50 : name_(std::move(name)), value_(c10::nullopt) {}
51 // Assignment operator. This enables the pybind-like syntax of
52 // torch::arg("name") = value.
53 arg& operator=(const c10::IValue& rhs) {
54 value_ = rhs;
55 return *this;
56 }
57
58 // The name of the argument. This is copied to the schema; argument
59 // names cannot be extracted from the C++ declaration.
60 std::string name_;
61 // IValue's default constructor makes it None, which is not distinguishable
62 // from an actual, user-provided default value that is None. This boolean
63 // helps distinguish between the two cases.
64 c10::optional<c10::IValue> value_;
65};
66
67namespace detail {
68
// Argument type utilities.
// Compile-time type list: R followed by the remaining pack. The nested
// `type` alias names the list itself.
template <class R, class... Rest>
struct types {
  using type = types<R, Rest...>;
};
74
75template <typename Method>
76struct WrapMethod;
77
78template <typename R, typename CurrClass, typename... Args>
79struct WrapMethod<R (CurrClass::*)(Args...)> {
80 WrapMethod(R (CurrClass::*m)(Args...)) : m(std::move(m)) {}
81
82 R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) {
83 return c10::guts::invoke(m, *cur, args...);
84 }
85
86 R (CurrClass::*m)(Args...);
87};
88
89template <typename R, typename CurrClass, typename... Args>
90struct WrapMethod<R (CurrClass::*)(Args...) const> {
91 WrapMethod(R (CurrClass::*m)(Args...) const) : m(std::move(m)) {}
92
93 R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) {
94 return c10::guts::invoke(m, *cur, args...);
95 }
96
97 R (CurrClass::*m)(Args...) const;
98};
99
100// Adapter for different callable types
101template <
102 typename CurClass,
103 typename Func,
104 std::enable_if_t<
105 std::is_member_function_pointer<std::decay_t<Func>>::value,
106 bool> = false>
107WrapMethod<Func> wrap_func(Func f) {
108 return WrapMethod<Func>(std::move(f));
109}
110
111template <
112 typename CurClass,
113 typename Func,
114 std::enable_if_t<
115 !std::is_member_function_pointer<std::decay_t<Func>>::value,
116 bool> = false>
117Func wrap_func(Func f) {
118 return f;
119}
120
121template <
122 class Functor,
123 bool AllowDeprecatedTypes,
124 size_t... ivalue_arg_indices>
125typename c10::guts::infer_function_traits_t<Functor>::return_type
126call_torchbind_method_from_stack(
127 Functor& functor,
128 jit::Stack& stack,
129 std::index_sequence<ivalue_arg_indices...>) {
130 (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would
131 // be unused and we have to silence the compiler warning.
132
133 constexpr size_t num_ivalue_args = sizeof...(ivalue_arg_indices);
134
135 using IValueArgTypes =
136 typename c10::guts::infer_function_traits_t<Functor>::parameter_types;
137 // TODO We shouldn't use c10::impl stuff directly here. We should use the
138 // KernelFunction API instead.
139 return (functor)(c10::impl::ivalue_to_arg<
140 typename c10::impl::decay_if_not_tensor<
141 c10::guts::typelist::
142 element_t<ivalue_arg_indices, IValueArgTypes>>::type,
143 AllowDeprecatedTypes>::
144 call(torch::jit::peek(
145 stack, ivalue_arg_indices, num_ivalue_args))...);
146}
147
148template <class Functor, bool AllowDeprecatedTypes>
149typename c10::guts::infer_function_traits_t<Functor>::return_type
150call_torchbind_method_from_stack(Functor& functor, jit::Stack& stack) {
151 constexpr size_t num_ivalue_args =
152 c10::guts::infer_function_traits_t<Functor>::number_of_parameters;
153 return call_torchbind_method_from_stack<Functor, AllowDeprecatedTypes>(
154 functor, stack, std::make_index_sequence<num_ivalue_args>());
155}
156
157template <class RetType, class Func>
158struct BoxedProxy;
159
160template <class RetType, class Func>
161struct BoxedProxy {
162 void operator()(jit::Stack& stack, Func& func) {
163 auto retval = call_torchbind_method_from_stack<Func, false>(func, stack);
164 constexpr size_t num_ivalue_args =
165 c10::guts::infer_function_traits_t<Func>::number_of_parameters;
166 torch::jit::drop(stack, num_ivalue_args);
167 stack.emplace_back(c10::ivalue::from(std::move(retval)));
168 }
169};
170
171template <class Func>
172struct BoxedProxy<void, Func> {
173 void operator()(jit::Stack& stack, Func& func) {
174 call_torchbind_method_from_stack<Func, false>(func, stack);
175 constexpr size_t num_ivalue_args =
176 c10::guts::infer_function_traits_t<Func>::number_of_parameters;
177 torch::jit::drop(stack, num_ivalue_args);
178 stack.emplace_back();
179 }
180};
181
// Returns true if character `n` may appear at index `i` of an identifier.
// Underscores and characters accepted by isalpha are allowed anywhere;
// characters accepted by isdigit are allowed anywhere except index 0.
// The character goes through unsigned char before reaching
// isalpha/isdigit, because those functions have undefined behavior for
// negative arguments other than EOF. `char` is signed on many platforms,
// so bytes >= 0x80 (e.g. UTF-8 continuation bytes) would otherwise be UB.
inline bool validIdent(size_t i, char n) {
  const int c = static_cast<unsigned char>(n);
  return isalpha(c) || n == '_' || (i > 0 && isdigit(c));
}
185
186inline void checkValidIdent(const std::string& str, const char* type) {
187 for (const auto i : c10::irange(str.size())) {
188 TORCH_CHECK(
189 validIdent(i, str[i]),
190 type,
191 " must be a valid Python/C++ identifier."
192 " Character '",
193 str[i],
194 "' at index ",
195 i,
196 " is illegal.");
197 }
198}
199
// Non-template base for custom-class registration. The constructor and
// withNewArguments are only declared here. Their definitions are out of
// line (presumably in a .cpp file — verify), so member order and
// signatures must stay in sync with them.
class TORCH_API class_base {
 protected:
  // NOTE(review): judging only by the parameter names, this sets up a class
  // named className in namespaceName with a doc string. It also receives
  // the std::type_info of the intrusive_ptr type and the tagged-capsule
  // type. Confirm the exact behavior in the out-of-line definition.
  explicit class_base(
      const std::string& namespaceName,
      const std::string& className,
      std::string doc_string,
      const std::type_info& intrusivePtrClassTypeid,
      const std::type_info& taggedCapsuleClass);

  // Returns a FunctionSchema built from `schema` and `default_args`
  // (torch::arg name/default pairs). NOTE(review): presumably it applies
  // the names and defaults to the schema's arguments — confirm in the
  // definition.
  static c10::FunctionSchema withNewArguments(
      const c10::FunctionSchema& schema,
      std::initializer_list<arg> default_args);
  // Qualified name of the class. NOTE(review): presumably of the form
  // __torch__.torch.classes.<namespace>.<class> — confirm where it is set.
  std::string qualClassName;
  // ClassType describing the class.
  at::ClassTypePtr classTypePtr;
};
215
216} // namespace detail
217
// NOTE(review): the two functions below are only declared here. From their
// names they appear to add a custom ClassType, and a method implementation,
// to a global registry — confirm against the definitions.
// registerCustomClassMethod takes ownership of `method` (passed as
// std::unique_ptr by value).
TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
TORCH_API void registerCustomClassMethod(std::unique_ptr<jit::Function> method);

// Given a qualified name (e.g. __torch__.torch.classes.Foo), return
// the ClassType pointer to the Type that describes that custom class,
// or nullptr if no class by that name was found.
TORCH_API at::ClassTypePtr getCustomClass(const std::string& name);

// Given an IValue, return true if the object contained in that IValue
// is a custom C++ class, otherwise return false.
TORCH_API bool isCustomClass(const c10::IValue& v);

// This API is for testing purposes ONLY. It should not be used in
// any load-bearing code.
// Returns a list of FunctionSchemas. NOTE(review): per the name, these are
// presumably the schemas of registered custom-class methods, used for
// backward-compatibility checks — confirm.
TORCH_API std::vector<c10::FunctionSchema> customClassSchemasForBCCheck();

// Re-export the registration functions so they are also reachable as
// torch::jit::registerCustomClass / torch::jit::registerCustomClassMethod.
namespace jit {
using ::torch::registerCustomClass;
using ::torch::registerCustomClassMethod;
} // namespace jit
238
239} // namespace torch
240