1 | #pragma once |
---|---|
2 | #include <pybind11/pybind11.h> |
3 | #include <pybind11/stl.h> |
4 | #include <torch/csrc/jit/api/module.h> |
5 | #include <torch/csrc/utils/pybind.h> |
6 | |
7 | namespace py = pybind11; |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
12 | inline c10::optional<Module> as_module(py::handle obj) { |
13 | static py::handle ScriptModule = |
14 | py::module::import("torch.jit").attr( "ScriptModule"); |
15 | if (py::isinstance(obj, ScriptModule)) { |
16 | return py::cast<Module>(obj.attr("_c")); |
17 | } |
18 | return c10::nullopt; |
19 | } |
20 | |
21 | inline c10::optional<Object> as_object(py::handle obj) { |
22 | static py::handle ScriptObject = |
23 | py::module::import("torch").attr( "ScriptObject"); |
24 | if (py::isinstance(obj, ScriptObject)) { |
25 | return py::cast<Object>(obj); |
26 | } |
27 | |
28 | static py::handle RecursiveScriptClass = |
29 | py::module::import("torch.jit").attr( "RecursiveScriptClass"); |
30 | if (py::isinstance(obj, RecursiveScriptClass)) { |
31 | return py::cast<Object>(obj.attr("_c")); |
32 | } |
33 | return c10::nullopt; |
34 | } |
35 | |
36 | } // namespace jit |
37 | } // namespace torch |
38 |