1 | #pragma once |
2 | |
3 | #include <ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h> |
4 | #include <ATen/core/function.h> |
5 | #include <c10/util/Metaprogramming.h> |
6 | #include <c10/util/TypeTraits.h> |
7 | #include <c10/util/irange.h> |
8 | |
9 | namespace torch { |
10 | |
11 | namespace detail { |
/**
 * In the Facebook internal build (using BUCK), this macro is enabled by
 * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer
 * binary.
 */
#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
TORCH_API void record_custom_class(std::string name);

/**
 * Helper behind RECORD_CUSTOM_CLASS: drops everything up to and including
 * the final '.' of `qualified_name` and hands the remainder to
 * record_custom_class(). A name that contains no '.' is recorded unchanged
 * (find_last_of returns npos, and npos + 1 wraps around to 0).
 */
inline void record_custom_class_from_qualified_name(
    const std::string& qualified_name) {
  record_custom_class(
      qualified_name.substr(qualified_name.find_last_of('.') + 1));
}

/**
 * Record an instance of a custom class being loaded
 * grab portion of string after final '.' from qualified name
 * as this seemingly aligns with how users name their custom classes
 * example: __torch__.torch.classes.xnnpack.Conv2dOpContext
 *
 * The expansion is one fully qualified function call, so the macro does not
 * declare any variable in the caller's scope and may be used several times in
 * the same scope or with an argument that is itself called `name`. The
 * trailing ';' is kept so that existing call sites written with or without a
 * semicolon keep compiling.
 */
#define RECORD_CUSTOM_CLASS(NAME)                           \
  ::torch::detail::record_custom_class_from_qualified_name( \
      std::string(NAME));
#else
#define RECORD_CUSTOM_CLASS(NAME)
#endif
32 | } // namespace detail |
33 | |
34 | /// This struct is used to represent default values for arguments |
35 | /// when registering methods for custom classes. |
36 | /// static auto register_foo = torch::class_<Foo>("myclasses", "Foo") |
37 | /// .def("myMethod", &Foo::myMethod, {torch::arg("name") = name}); |
38 | struct arg { |
39 | // Static method for representing a default value of None. This is meant to |
40 | // be used like so: |
41 | // torch::arg("name") = torch::arg::none |
42 | // and is identical to: |
43 | // torch::arg("name") = IValue() |
44 | static c10::IValue none() { |
45 | return c10::IValue(); |
46 | } |
47 | |
48 | // Explicit constructor. |
49 | explicit arg(std::string name) |
50 | : name_(std::move(name)), value_(c10::nullopt) {} |
51 | // Assignment operator. This enables the pybind-like syntax of |
52 | // torch::arg("name") = value. |
53 | arg& operator=(const c10::IValue& rhs) { |
54 | value_ = rhs; |
55 | return *this; |
56 | } |
57 | |
58 | // The name of the argument. This is copied to the schema; argument |
59 | // names cannot be extracted from the C++ declaration. |
60 | std::string name_; |
61 | // IValue's default constructor makes it None, which is not distinguishable |
62 | // from an actual, user-provided default value that is None. This boolean |
63 | // helps distinguish between the two cases. |
64 | c10::optional<c10::IValue> value_; |
65 | }; |
66 | |
67 | namespace detail { |
68 | |
// Argument type utilities.
// `types<R, Ts...>` is an empty tag that carries a list of types (at least
// one, R); its nested `type` names that very specialization.
template <class R, class... Rest>
struct types {
  using type = types<R, Rest...>;
};
74 | |
75 | template <typename Method> |
76 | struct WrapMethod; |
77 | |
78 | template <typename R, typename CurrClass, typename... Args> |
79 | struct WrapMethod<R (CurrClass::*)(Args...)> { |
80 | WrapMethod(R (CurrClass::*m)(Args...)) : m(std::move(m)) {} |
81 | |
82 | R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) { |
83 | return c10::guts::invoke(m, *cur, args...); |
84 | } |
85 | |
86 | R (CurrClass::*m)(Args...); |
87 | }; |
88 | |
89 | template <typename R, typename CurrClass, typename... Args> |
90 | struct WrapMethod<R (CurrClass::*)(Args...) const> { |
91 | WrapMethod(R (CurrClass::*m)(Args...) const) : m(std::move(m)) {} |
92 | |
93 | R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) { |
94 | return c10::guts::invoke(m, *cur, args...); |
95 | } |
96 | |
97 | R (CurrClass::*m)(Args...) const; |
98 | }; |
99 | |
100 | // Adapter for different callable types |
101 | template < |
102 | typename CurClass, |
103 | typename Func, |
104 | std::enable_if_t< |
105 | std::is_member_function_pointer<std::decay_t<Func>>::value, |
106 | bool> = false> |
107 | WrapMethod<Func> wrap_func(Func f) { |
108 | return WrapMethod<Func>(std::move(f)); |
109 | } |
110 | |
// Counterpart of the member-function-pointer adapter: selected for every
// callable that is NOT a pointer to a member function (lambdas, functors,
// free function pointers, ...). Such callables need no wrapping, so the
// argument is handed back as is.
template <
    typename CurClass,
    typename Func,
    std::enable_if_t<
        !std::is_member_function_pointer<std::decay_t<Func>>::value,
        bool> = false>
Func wrap_func(Func f) {
  return f;
}
120 | |
121 | template < |
122 | class Functor, |
123 | bool AllowDeprecatedTypes, |
124 | size_t... ivalue_arg_indices> |
125 | typename c10::guts::infer_function_traits_t<Functor>::return_type |
126 | call_torchbind_method_from_stack( |
127 | Functor& functor, |
128 | jit::Stack& stack, |
129 | std::index_sequence<ivalue_arg_indices...>) { |
130 | (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would |
131 | // be unused and we have to silence the compiler warning. |
132 | |
133 | constexpr size_t num_ivalue_args = sizeof...(ivalue_arg_indices); |
134 | |
135 | using IValueArgTypes = |
136 | typename c10::guts::infer_function_traits_t<Functor>::parameter_types; |
137 | // TODO We shouldn't use c10::impl stuff directly here. We should use the |
138 | // KernelFunction API instead. |
139 | return (functor)(c10::impl::ivalue_to_arg< |
140 | typename c10::impl::decay_if_not_tensor< |
141 | c10::guts::typelist:: |
142 | element_t<ivalue_arg_indices, IValueArgTypes>>::type, |
143 | AllowDeprecatedTypes>:: |
144 | call(torch::jit::peek( |
145 | stack, ivalue_arg_indices, num_ivalue_args))...); |
146 | } |
147 | |
148 | template <class Functor, bool AllowDeprecatedTypes> |
149 | typename c10::guts::infer_function_traits_t<Functor>::return_type |
150 | call_torchbind_method_from_stack(Functor& functor, jit::Stack& stack) { |
151 | constexpr size_t num_ivalue_args = |
152 | c10::guts::infer_function_traits_t<Functor>::number_of_parameters; |
153 | return call_torchbind_method_from_stack<Functor, AllowDeprecatedTypes>( |
154 | functor, stack, std::make_index_sequence<num_ivalue_args>()); |
155 | } |
156 | |
157 | template <class RetType, class Func> |
158 | struct BoxedProxy; |
159 | |
160 | template <class RetType, class Func> |
161 | struct BoxedProxy { |
162 | void operator()(jit::Stack& stack, Func& func) { |
163 | auto retval = call_torchbind_method_from_stack<Func, false>(func, stack); |
164 | constexpr size_t num_ivalue_args = |
165 | c10::guts::infer_function_traits_t<Func>::number_of_parameters; |
166 | torch::jit::drop(stack, num_ivalue_args); |
167 | stack.emplace_back(c10::ivalue::from(std::move(retval))); |
168 | } |
169 | }; |
170 | |
171 | template <class Func> |
172 | struct BoxedProxy<void, Func> { |
173 | void operator()(jit::Stack& stack, Func& func) { |
174 | call_torchbind_method_from_stack<Func, false>(func, stack); |
175 | constexpr size_t num_ivalue_args = |
176 | c10::guts::infer_function_traits_t<Func>::number_of_parameters; |
177 | torch::jit::drop(stack, num_ivalue_args); |
178 | stack.emplace_back(); |
179 | } |
180 | }; |
181 | |
// Returns true if character `n`, located at index `i` of an identifier, is
// allowed there: letters and '_' are accepted at any position, digits only
// after the first character (i > 0).
inline bool validIdent(size_t i, char n) {
  // The <cctype> classification functions require a value representable as
  // unsigned char (or EOF). Passing a plain char that happens to be negative
  // (any byte >= 0x80 where char is signed) is undefined behavior, so convert
  // first.
  const auto ch = static_cast<unsigned char>(n);
  return isalpha(ch) || n == '_' || (i > 0 && isdigit(ch));
}
185 | |
186 | inline void checkValidIdent(const std::string& str, const char* type) { |
187 | for (const auto i : c10::irange(str.size())) { |
188 | TORCH_CHECK( |
189 | validIdent(i, str[i]), |
190 | type, |
191 | " must be a valid Python/C++ identifier." |
192 | " Character '" , |
193 | str[i], |
194 | "' at index " , |
195 | i, |
196 | " is illegal." ); |
197 | } |
198 | } |
199 | |
// Exported base class whose constructor and members are all protected, i.e.
// it is only usable through derived classes.
// NOTE(review): presumably the non-template base of torch::class_<T> that
// keeps the template-independent registration logic out of the header -
// confirm against the class that derives from it.
class TORCH_API class_base {
 protected:
  // Defined out of line. Takes the namespace and class name, a doc string and
  // two std::type_info references.
  // NOTE(review): presumably the type_infos identify
  // c10::intrusive_ptr<CurClass> and the tagged-capsule type of the class
  // being registered - confirm against the definition.
  explicit class_base(
      const std::string& namespaceName,
      const std::string& className,
      std::string doc_string,
      const std::type_info& intrusivePtrClassTypeid,
      const std::type_info& taggedCapsuleClass);

  // Returns a FunctionSchema built from `schema` and `default_args`.
  // NOTE(review): presumably a copy of `schema` whose arguments carry the
  // names/default values given in `default_args` - confirm against the
  // definition.
  static c10::FunctionSchema withNewArguments(
      const c10::FunctionSchema& schema,
      std::initializer_list<arg> default_args);
  // NOTE(review): presumably the fully qualified class name, e.g.
  // __torch__.torch.classes.<namespace>.<class> - confirm.
  std::string qualClassName;
  // NOTE(review): presumably the ClassType describing the class being
  // registered - confirm.
  at::ClassTypePtr classTypePtr;
};
215 | |
216 | } // namespace detail |
217 | |
// Registers `class_type` as a custom class.
// NOTE(review): presumably this is what makes the type discoverable through
// getCustomClass() - confirm against the definition.
TORCH_API void registerCustomClass(at::ClassTypePtr class_type);
// Registers `method` as a custom class method; ownership of the function
// object is transferred to the callee (it is passed as a std::unique_ptr by
// value).
TORCH_API void registerCustomClassMethod(std::unique_ptr<jit::Function> method);

// Given a qualified name (e.g. __torch__.torch.classes.Foo), return
// the ClassType pointer to the Type that describes that custom class,
// or nullptr if no class by that name was found.
TORCH_API at::ClassTypePtr getCustomClass(const std::string& name);

// Given an IValue, return true if the object contained in that IValue
// is a custom C++ class, otherwise return false.
TORCH_API bool isCustomClass(const c10::IValue& v);

// This API is for testing purposes ONLY. It should not be used in
// any load-bearing code.
// NOTE(review): judging by the name, it returns the schemas of the registered
// custom class methods for a backward-compatibility check - confirm.
TORCH_API std::vector<c10::FunctionSchema> customClassSchemasForBCCheck();
233 | |
// Makes the two registration functions declared in namespace torch also
// reachable as torch::jit::registerCustomClass and
// torch::jit::registerCustomClassMethod.
// NOTE(review): presumably kept for callers that still use the torch::jit
// spelling - confirm.
namespace jit {
using ::torch::registerCustomClass;
using ::torch::registerCustomClassMethod;
} // namespace jit
238 | |
239 | } // namespace torch |
240 | |