1 | // Copyright (c) Facebook, Inc. and its affiliates. |
2 | // All rights reserved. |
3 | // |
4 | // This source code is licensed under the BSD-style license found in the |
5 | // LICENSE file in the root directory of this source tree. |
6 | |
7 | #include <ATen/FunctionalTensorWrapper.h> |
8 | #include <ATen/WrapDimUtils.h> |
9 | #include <torch/python.h> |
10 | |
11 | #include <ATen/functorch/BatchRulesHelper.h> |
12 | #include <ATen/functorch/BatchedFallback.h> |
13 | #include <ATen/functorch/BatchedTensorImpl.h> |
14 | #include <ATen/functorch/DynamicLayer.h> |
15 | #include <ATen/functorch/Interpreter.h> |
16 | #include <ATen/functorch/LegacyVmapTransforms.h> |
17 | #include <ATen/functorch/PlumbingHelper.h> |
18 | #include <ATen/functorch/TensorWrapper.h> |
19 | #include <c10/core/AutogradState.h> |
20 | |
21 | // This file contains functorch's Python bindings. |
22 | |
23 | namespace torch { |
24 | namespace functorch { |
25 | namespace impl { |
26 | |
27 | using namespace at::functorch; |
28 | |
29 | static bool has_level(const Tensor& self, int64_t level) { |
30 | const auto* batched = maybeGetBatchedImpl(self); |
31 | if (!batched) { |
32 | return false; |
33 | } |
34 | return batched->level() >= level; |
35 | } |
36 | |
37 | Tensor _add_batch_dim(const Tensor& self, int64_t batch_dim, int64_t level) { |
38 | return addBatchDim(self, batch_dim, level); |
39 | } |
40 | |
41 | Tensor _wrap_functional_tensor(const Tensor& self, int64_t level) { |
42 | auto t = at::functionalization::impl::to_functional_tensor(self); |
43 | at::functionalization::impl::unsafeGetFunctionalWrapper(t)->set_level(level); |
44 | return t; |
45 | } |
46 | |
47 | void _assert_wrapped_functional( |
48 | const Tensor& unwrapped, |
49 | const Tensor& wrapped) { |
50 | TORCH_INTERNAL_ASSERT( |
51 | at::functionalization::impl::isFunctionalTensor(wrapped)); |
52 | TORCH_INTERNAL_ASSERT( |
53 | !at::functionalization::impl::isFunctionalTensor(unwrapped)); |
54 | auto wrapped_impl = |
55 | at::functionalization::impl::unsafeGetFunctionalWrapper(wrapped); |
56 | auto& wrapped_inner = wrapped_impl->value(); |
57 | TORCH_INTERNAL_ASSERT( |
58 | unwrapped.unsafeGetTensorImpl() == wrapped_inner.unsafeGetTensorImpl()) |
59 | } |
60 | |
61 | void _propagate_functional_input_mutation( |
62 | const Tensor& unwrapped, |
63 | const Tensor& wrapped) { |
64 | TORCH_INTERNAL_ASSERT( |
65 | at::functionalization::impl::isFunctionalTensor(wrapped)); |
66 | TORCH_INTERNAL_ASSERT( |
67 | !at::functionalization::impl::isFunctionalTensor(unwrapped)); |
68 | auto wrapped_impl = |
69 | at::functionalization::impl::unsafeGetFunctionalWrapper(wrapped); |
70 | // Ensure that the input is up to date by committing any pending updates to |
71 | // the alias. |
72 | wrapped_impl->sync_(); |
73 | auto& wrapped_inner = wrapped_impl->value(); |
74 | // It would probably be more reasonable to check that the two tensors are |
75 | // aliased, but we can't do that unless we give BatchedTensorImpl a notion of |
76 | // storage. |
77 | if (unwrapped.unsafeGetTensorImpl() == wrapped_inner.unsafeGetTensorImpl()) { |
78 | } else { |
79 | if (unwrapped.sym_nbytes() != wrapped_inner.sym_nbytes()) { |
80 | // Functions might resize zero-sized inputs, which we need to reflect |
81 | // ehre. |
82 | unwrapped.resize__symint(wrapped_inner.sym_sizes()); |
83 | } |
84 | // If the input tensor's metadata was mutated, then use as_strided_() |
85 | // to propagate the metadata change. |
86 | if (unwrapped.sym_sizes() != wrapped_inner.sym_sizes()) { |
87 | unwrapped.as_strided__symint( |
88 | wrapped_inner.sym_sizes(), wrapped_inner.sym_strides()); |
89 | } |
90 | unwrapped.copy_(wrapped_inner); |
91 | } |
92 | } |
93 | |
94 | static std::pair<Tensor, int64_t> remove_existing_batch_dim( |
95 | const BatchedTensorImpl* batched, |
96 | int64_t level) { |
97 | TORCH_INTERNAL_ASSERT(batched->level() == level); |
98 | return std::make_pair(batched->value(), batched->bdim()); |
99 | } |
100 | |
// Poor man's version of np.moveaxis. Moves the dimension at `src` to `dst`
// while preserving the order of other existing dimensions.
103 | // We should probably add np.moveaxis (it is more general) to PyTorch. (#36048) |
104 | // When we do, replace the following with it. |
105 | static Tensor _movedim(const Tensor& self, int64_t src, int64_t dst) { |
106 | auto logical_dim = self.dim(); |
107 | src = at::maybe_wrap_dim(src, logical_dim); |
108 | dst = at::maybe_wrap_dim(dst, logical_dim); |
109 | if (src == dst) { |
110 | return self; |
111 | } |
112 | VmapDimVector permutation; |
113 | permutation.reserve(logical_dim); |
114 | for (int64_t dim = 0; dim < logical_dim; dim++) { |
115 | if (dim == src) { |
116 | continue; |
117 | } |
118 | permutation.push_back(dim); |
119 | } |
120 | permutation.insert(permutation.begin() + dst, src); |
121 | return self.permute(permutation); |
122 | } |
123 | |
124 | // Removes the batch dim with level `level` from `self`. If this causes the |
125 | // last batch dim to be removed from a BatchedTensor, then this returns a |
126 | // regular Tensor. |
127 | // |
128 | // If the `level` of the batch dim to remove does not exist in `self`, then we |
129 | // add the batch dim in. This can happen if `self` didn't interact with a tensor |
130 | // inside the vmap level, for example, |
131 | // self = torch.randn(3) |
132 | // y = torch.randn(5) |
133 | // out = vmap(lambda x: vmap(lambda y: x)(y))(self) |
134 | // assert out.shape == (3, 5) |
135 | // Inside the inner vmap, `x` is a BatchedTensor with a single batch dimension |
136 | // corresponding to the *outer* vmap level and it doesn't have any dimensions |
137 | // that correspond to the inner vmap level so we need to create one for the |
138 | // user. |
139 | // |
140 | // `out_dim` controls where we should put the batch dimension in the output |
141 | // tensor. |
142 | Tensor _remove_batch_dim( |
143 | const Tensor& self, |
144 | int64_t level, |
145 | int64_t batch_size, |
146 | int64_t out_dim) { |
147 | if (!has_level(self, level)) { |
148 | auto self_sizes = self.sizes(); |
149 | VmapDimVector expanded_sizes(self_sizes.begin(), self_sizes.end()); |
150 | expanded_sizes.insert(expanded_sizes.begin() + out_dim, batch_size); |
151 | auto result = self.expand(expanded_sizes); |
152 | return result; |
153 | } |
154 | |
155 | // Must be batched if has_level(self, /*any_level*/) |
156 | const auto* batched = maybeGetBatchedImpl(self); |
157 | TORCH_INTERNAL_ASSERT(batched != nullptr); |
158 | |
159 | Tensor self_without_bdim; |
160 | int64_t newly_exposed_logical_dim; |
161 | std::tie(self_without_bdim, newly_exposed_logical_dim) = |
162 | remove_existing_batch_dim(batched, level); |
163 | auto result = _movedim(self_without_bdim, newly_exposed_logical_dim, out_dim); |
164 | return result; |
165 | } |
166 | |
167 | Tensor _unwrap_functional_tensor(const Tensor& self, bool add_back_views) { |
168 | // We only ever call that after popping out of a functionalize() call, in |
169 | // which case the current tensors should always be wrapped in a |
170 | // FunctionalTensorWrapper. |
171 | TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(self)); |
172 | auto functional = |
173 | at::functionalization::impl::unsafeGetFunctionalWrapper(self); |
174 | |
175 | // when regenerating the (potentially mutated) input tensors, the |
176 | // functionalization pass regenerates them through a series of view_copy() op |
177 | // calls. Functorch wants to turn those back into view ops though. Ensure that |
178 | // the input is up to date by committing any pending updates to the alias. |
179 | at::functionalization::impl::FunctionalizationReapplyViewsGuard guard( |
180 | add_back_views); |
181 | bool any_updates = functional->apply_updates(); |
182 | if (any_updates) { |
183 | functional->regenerate_from_base(); |
184 | } |
185 | return functional->value(); |
186 | } |
187 | |
188 | Tensor _wrap_for_grad(const Tensor& self, int64_t level) { |
189 | // NB: different behavior inside?? |
190 | // return self; |
191 | // TORCH_INTERNAL_ASSERT(!maybeGetTensorWrapper(self)); |
192 | // TORCH_INTERNAL_ASSERT(self.has_storage()); |
193 | return makeTensorWrapper(self, level); |
194 | } |
195 | |
196 | Tensor _unwrap_for_grad(const Tensor& self, int64_t level) { |
197 | auto* result = maybeGetTensorWrapper(self); |
198 | if (!result) { |
199 | return self; |
200 | } |
201 | TORCH_INTERNAL_ASSERT(result->level().has_value()); |
202 | if (result->level() == level) { |
203 | return result->value(); |
204 | } |
205 | return self; |
206 | } |
207 | |
208 | int64_t dlevel(const Tensor& tensor) { |
209 | auto* wrapped = maybeGetTensorWrapper(tensor); |
210 | if (!wrapped) { |
211 | return 0; |
212 | } |
213 | if (!wrapped->is_alive()) { |
214 | return -1; |
215 | } |
216 | return wrapped->level().value(); |
217 | } |
218 | |
219 | bool dump_tensor(const Tensor& self) { |
220 | dumpTensorCout(self); |
221 | return true; |
222 | } |
223 | |
224 | RandomnessType get_randomness_enum(const std::string& randomness) { |
225 | if (randomness == "error" ) { |
226 | return RandomnessType::Error; |
227 | } else if (randomness == "same" ) { |
228 | return RandomnessType::Same; |
229 | } else if (randomness == "different" ) { |
230 | return RandomnessType::Different; |
231 | } else { |
232 | TORCH_CHECK( |
233 | false, "randomness argument must be error, same, or different." ); |
234 | } |
235 | } |
236 | |
237 | int64_t _grad_increment_nesting() { |
238 | // See NOTE [grad and vjp interaction with no_grad] |
239 | bool prev_grad_mode = c10::GradMode::is_enabled(); |
240 | return initAndPushDynamicLayer( |
241 | TransformType::Grad, c10::nullopt, c10::nullopt, prev_grad_mode); |
242 | } |
243 | |
244 | int64_t _grad_decrement_nesting() { |
245 | auto layer = popDynamicLayerAndDeleteMetadata(); |
246 | TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Grad); |
247 | return layer.layerId(); |
248 | } |
249 | |
250 | int64_t _jvp_increment_nesting() { |
251 | // See NOTE [grad and vjp interaction with no_grad] |
252 | bool prev_fwd_grad_mode = |
253 | c10::AutogradState::get_tls_state().get_fw_grad_mode(); |
254 | return initAndPushDynamicLayer( |
255 | TransformType::Jvp, |
256 | c10::nullopt, |
257 | c10::nullopt, |
258 | c10::nullopt, |
259 | prev_fwd_grad_mode); |
260 | } |
261 | |
262 | int64_t _jvp_decrement_nesting() { |
263 | auto layer = popDynamicLayerAndDeleteMetadata(); |
264 | TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Jvp); |
265 | return layer.layerId(); |
266 | } |
267 | |
268 | int64_t _vmap_increment_nesting( |
269 | int64_t batch_size, |
270 | const std::string& randomness) { |
271 | return initAndPushDynamicLayer( |
272 | TransformType::Vmap, batch_size, get_randomness_enum(randomness)); |
273 | } |
274 | |
275 | int64_t _vmap_decrement_nesting() { |
276 | auto layer = popDynamicLayerAndDeleteMetadata(); |
277 | TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Vmap); |
278 | return layer.layerId(); |
279 | } |
280 | |
281 | int64_t _func_increment_nesting(bool reapply_views) { |
282 | return initAndPushDynamicLayer( |
283 | TransformType::Functionalize, |
284 | c10::nullopt, |
285 | c10::nullopt, |
286 | c10::nullopt, |
287 | c10::nullopt, |
288 | /*functionalize_add_back_views=*/reapply_views); |
289 | } |
290 | |
291 | int64_t _func_decrement_nesting() { |
292 | auto layer = popDynamicLayerAndDeleteMetadata(); |
293 | TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Functionalize); |
294 | return layer.layerId(); |
295 | } |
296 | |
297 | static bool is_batchedtensor(const Tensor& tensor) { |
298 | auto* batched = maybeGetBatchedImpl(tensor); |
299 | return batched != nullptr; |
300 | } |
301 | |
302 | static bool is_gradtrackingtensor(const Tensor& tensor) { |
303 | auto* wrapped = maybeGetTensorWrapper(tensor); |
304 | return wrapped != nullptr; |
305 | } |
306 | |
307 | static bool is_functionaltensor(const Tensor& tensor) { |
308 | return tensor.unsafeGetTensorImpl()->key_set().has( |
309 | c10::DispatchKey::Functionalize); |
310 | } |
311 | |
312 | static Tensor get_unwrapped(const Tensor& tensor) { |
313 | auto* batched = maybeGetBatchedImpl(tensor); |
314 | if (batched) { |
315 | return batched->value(); |
316 | } |
317 | auto* wrapped = maybeGetTensorWrapper(tensor); |
318 | if (wrapped) { |
319 | return wrapped->value(); |
320 | } |
321 | if (at::functionalization::impl::isFunctionalTensor(tensor)) { |
322 | auto* functional = |
323 | at::functionalization::impl::unsafeGetFunctionalWrapper(tensor); |
324 | return functional->value(); |
325 | } |
326 | TORCH_CHECK(false, "No wrappers present!" ); |
327 | } |
328 | |
329 | static int64_t maybe_get_level(const Tensor& tensor) { |
330 | auto* batched = maybeGetBatchedImpl(tensor); |
331 | if (batched) { |
332 | return batched->level(); |
333 | } |
334 | auto* wrapped = maybeGetTensorWrapper(tensor); |
335 | if (wrapped) { |
336 | if (wrapped->level()) { |
337 | return *wrapped->level(); |
338 | } |
339 | // TODO: this is a weird special case... |
340 | return -2; |
341 | } |
342 | if (at::functionalization::impl::isFunctionalTensor(tensor)) { |
343 | auto* functional = |
344 | at::functionalization::impl::unsafeGetFunctionalWrapper(tensor); |
345 | return functional->level(); |
346 | } |
347 | return -1; |
348 | } |
349 | |
350 | static int64_t maybe_get_bdim(const Tensor& tensor) { |
351 | auto* batched = maybeGetBatchedImpl(tensor); |
352 | if (batched) { |
353 | return batched->bdim(); |
354 | } |
355 | return -1; |
356 | } |
357 | |
358 | static int64_t currentLevel() { |
359 | auto maybe_layer = maybeCurrentDynamicLayer(); |
360 | TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); |
361 | int64_t current_level = maybe_layer->layerId(); |
362 | return current_level; |
363 | } |
364 | |
365 | static void tls_set_vmap_excluded(bool excluded) { |
366 | c10::impl::tls_set_dispatch_key_excluded( |
367 | c10::DispatchKey::FuncTorchBatched, excluded); |
368 | } |
369 | |
370 | static void _set_dynamic_layer_keys_included(bool value) { |
371 | return setDynamicLayerFrontBackKeysIncluded(value); |
372 | } |
373 | |
374 | static void dump_dls() { |
375 | std::cout << getDynamicLayerStack() << std::endl; |
376 | } |
377 | |
378 | static void dump_local_tls() { |
379 | auto tls = c10::impl::tls_local_dispatch_key_set(); |
380 | std::cout << "[Local Include] " << tls.included_ << std::endl; |
381 | std::cout << "[Local Exclude] " << tls.excluded_ << std::endl; |
382 | } |
383 | |
384 | static std::tuple<Tensor, c10::optional<int64_t>> unwrapBatched( |
385 | const Tensor& tensor, |
386 | int64_t level) { |
387 | auto* batched = maybeGetBatchedImpl(tensor); |
388 | if (!batched) { |
389 | return std::make_tuple(tensor, nullopt); |
390 | } |
391 | if (batched->level() == level) { |
392 | return std::make_tuple(batched->value(), batched->bdim()); |
393 | } |
394 | return std::make_tuple(tensor, nullopt); |
395 | } |
396 | |
397 | void initFuncTorchBindings(PyObject* module) { |
398 | auto _C = py::handle(module).cast<py::module>(); |
399 | auto m = _C.def_submodule("_functorch" ); |
400 | |
401 | m.def("_add_batch_dim" , &_add_batch_dim, "add batch dim" ); |
402 | m.def("_remove_batch_dim" , &_remove_batch_dim, "remove batch dim" ); |
403 | m.def("_unwrap_batched" , &unwrapBatched); |
404 | m.def( |
405 | "_wrap_functional_tensor" , |
406 | &_wrap_functional_tensor, |
407 | "add functional tensor" ); |
408 | m.def( |
409 | "_assert_wrapped_functional" , |
410 | &_assert_wrapped_functional, |
411 | "assert wrapped functional" ); |
412 | m.def( |
413 | "_propagate_functional_input_mutation" , |
414 | &_propagate_functional_input_mutation, |
415 | "propagate functional input mutations" ); |
416 | m.def( |
417 | "_unwrap_functional_tensor" , |
418 | &_unwrap_functional_tensor, |
419 | "remove functional tensor" ); |
420 | m.def("_vmap_increment_nesting" , &_vmap_increment_nesting); |
421 | m.def("_vmap_decrement_nesting" , &_vmap_decrement_nesting); |
422 | m.def( |
423 | "_func_increment_nesting" , |
424 | &_func_increment_nesting, |
425 | "functionalization start" ); |
426 | m.def( |
427 | "_func_decrement_nesting" , |
428 | &_func_decrement_nesting, |
429 | "functionalization end" ); |
430 | m.def("_grad_increment_nesting" , &_grad_increment_nesting); |
431 | m.def("_grad_decrement_nesting" , &_grad_decrement_nesting); |
432 | m.def("_jvp_increment_nesting" , &_jvp_increment_nesting); |
433 | m.def("_jvp_decrement_nesting" , &_jvp_decrement_nesting); |
434 | m.def("_wrap_for_grad" , &_wrap_for_grad, "wrap as gradtrackingtensor" ); |
435 | m.def( |
436 | "_unwrap_for_grad" , &_unwrap_for_grad, "unwrap from gradtrackingtensor" ); |
437 | m.def( |
438 | "_set_vmap_fallback_warning_enabled" , |
439 | &at::functorch::setVmapFallbackWarningEnabled, |
440 | "Set vmap fallback warnings" ); |
441 | m.def("_set_vmap_fallback_enabled" , &at::functorch::setVmapFallbackEnabled); |
442 | m.def("_is_vmap_fallback_enabled" , &at::functorch::isVmapFallbackEnabled); |
443 | m.def( |
444 | "set_inplace_requires_grad_allowed" , |
445 | &at::functorch::setInplaceRequiresGradAllowed); |
446 | m.def( |
447 | "get_inplace_requires_grad_allowed" , |
448 | &at::functorch::getInplaceRequiresGradAllowed); |
449 | m.def( |
450 | "set_single_level_autograd_function_allowed" , |
451 | &at::functorch::setSingleLevelAutogradFunctionAllowed); |
452 | m.def( |
453 | "get_single_level_autograd_function_allowed" , |
454 | &at::functorch::getSingleLevelAutogradFunctionAllowed); |
455 | m.def("unwrap_if_dead" , &unwrapIfDead); |
456 | m.def("is_dead_tensor_wrapper" , &isDeadTensorWrapper); |
457 | m.def("dlevel" , &dlevel, "dlevel" ); |
458 | m.def("dump_tensor" , &dump_tensor, "dump_tensor" ); |
459 | m.def("reshape_dim_into" , &at::functorch::reshape_dim_into); |
460 | m.def("reshape_dim_outof" , &at::functorch::reshape_dim_outof); |
461 | // various debugging things. Maybe we should offer these as first-class APIs |
462 | // on Tensors? |
463 | m.def("is_batchedtensor" , &is_batchedtensor); |
464 | m.def("is_gradtrackingtensor" , &is_gradtrackingtensor); |
465 | m.def("is_functionaltensor" , &is_functionaltensor); |
466 | m.def("get_unwrapped" , &get_unwrapped); |
467 | m.def("maybe_get_level" , &maybe_get_level); |
468 | m.def("maybe_get_bdim" , &maybe_get_bdim); |
469 | m.def("current_level" , ¤tLevel); |
470 | m.def("tls_set_vmap_excluded" , &tls_set_vmap_excluded); |
471 | m.def("_set_dynamic_layer_keys_included" , &_set_dynamic_layer_keys_included); |
472 | m.def("dump_dls" , &dump_dls); |
473 | m.def("dump_local_tls" , &dump_local_tls); |
474 | m.def("is_functorch_wrapped_tensor" , [](const Tensor& tensor) { |
475 | return maybe_get_level(tensor) != -1; |
476 | }); |
477 | m.def("peek_interpreter_stack" , []() -> c10::optional<Interpreter> { |
478 | const auto& stack = getDynamicLayerStack(); |
479 | if (stack.empty()) { |
480 | return c10::nullopt; |
481 | } |
482 | auto result = stack.back().interpreter(); |
483 | return result; |
484 | }); |
485 | m.def("pop_dynamic_layer_stack" , &popDynamicLayer); |
486 | m.def("push_dynamic_layer_stack" , [](DynamicLayer layer) -> int64_t { |
487 | return pushDynamicLayer(std::move(layer)); |
488 | }); |
489 | py::class_<DynamicLayer>(m, "DynamicLayer" ); |
490 | |
491 | py::enum_<TransformType>(m, "TransformType" ) |
492 | .value("Torch" , TransformType::Torch) |
493 | .value("Grad" , TransformType::Grad) |
494 | .value("Jvp" , TransformType::Jvp) |
495 | .value("Functionalize" , TransformType::Functionalize) |
496 | .value("Vmap" , TransformType::Vmap); |
497 | py::enum_<RandomnessType>(m, "RandomnessType" ) |
498 | .value("Error" , RandomnessType::Error) |
499 | .value("Same" , RandomnessType::Same) |
500 | .value("Different" , RandomnessType::Different); |
501 | py::class_<Interpreter>(m, "CInterpreter" ) |
502 | .def("key" , &Interpreter::key) |
503 | .def("level" , &Interpreter::level); |
504 | py::class_<GradInterpreterPtr>(m, "CGradInterpreterPtr" ) |
505 | .def(py::init<const Interpreter*>()) |
506 | .def("key" , &GradInterpreterPtr::key) |
507 | .def("level" , &GradInterpreterPtr::level) |
508 | .def("lift" , &GradInterpreterPtr::lift) |
509 | .def("prevGradMode" , &GradInterpreterPtr::prevGradMode); |
510 | py::class_<JvpInterpreterPtr>(m, "CJvpInterpreterPtr" ) |
511 | .def(py::init<const Interpreter*>()) |
512 | .def("key" , &JvpInterpreterPtr::key) |
513 | .def("level" , &JvpInterpreterPtr::level) |
514 | .def("lift" , &JvpInterpreterPtr::lift) |
515 | .def("prevFwdGradMode" , &JvpInterpreterPtr::prevFwdGradMode); |
516 | py::class_<VmapInterpreterPtr>(m, "CVmapInterpreterPtr" ) |
517 | .def(py::init<const Interpreter*>()) |
518 | .def("key" , &VmapInterpreterPtr::key) |
519 | .def("level" , &VmapInterpreterPtr::level) |
520 | .def("batchSize" , &VmapInterpreterPtr::batchSize) |
521 | .def("randomness" , &VmapInterpreterPtr::randomness); |
522 | py::class_<FunctionalizeInterpreterPtr>(m, "CFunctionalizeInterpreterPtr" ) |
523 | .def(py::init<const Interpreter*>()) |
524 | .def("key" , &FunctionalizeInterpreterPtr::key) |
525 | .def("level" , &FunctionalizeInterpreterPtr::level) |
526 | .def( |
527 | "functionalizeAddBackViews" , |
528 | &FunctionalizeInterpreterPtr::functionalizeAddBackViews); |
529 | } |
530 | |
531 | } // namespace impl |
532 | } // namespace functorch |
533 | } // namespace torch |
534 | |