1#pragma once
2#include <pybind11/pybind11.h>
3#include <pybind11/stl.h>
4#include <torch/csrc/jit/api/module.h>
5#include <torch/csrc/utils/pybind.h>
6
7namespace py = pybind11;
8
9namespace torch {
10namespace jit {
11
12inline c10::optional<Module> as_module(py::handle obj) {
13 static py::handle ScriptModule =
14 py::module::import("torch.jit").attr("ScriptModule");
15 if (py::isinstance(obj, ScriptModule)) {
16 return py::cast<Module>(obj.attr("_c"));
17 }
18 return c10::nullopt;
19}
20
21inline c10::optional<Object> as_object(py::handle obj) {
22 static py::handle ScriptObject =
23 py::module::import("torch").attr("ScriptObject");
24 if (py::isinstance(obj, ScriptObject)) {
25 return py::cast<Object>(obj);
26 }
27
28 static py::handle RecursiveScriptClass =
29 py::module::import("torch.jit").attr("RecursiveScriptClass");
30 if (py::isinstance(obj, RecursiveScriptClass)) {
31 return py::cast<Object>(obj.attr("_c"));
32 }
33 return c10::nullopt;
34}
35
36} // namespace jit
37} // namespace torch
38