1#include <torch/csrc/python_headers.h>
2
3#include <libshm.h>
4#include <cstdlib>
5
6#include <pybind11/detail/common.h>
7#include <pybind11/functional.h>
8#include <pybind11/pybind11.h>
9#include <pybind11/pytypes.h>
10#include <pybind11/stl.h>
11#include <pybind11/stl_bind.h>
12#include <torch/csrc/utils/pybind.h>
13
14#include <Python.h> // NOLINT
15#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
16#include <torch/csrc/jit/python/module_python.h>
17#include <torch/csrc/jit/python/python_ivalue.h>
18#include <torch/csrc/jit/python/python_sugared_value.h>
19#include <torch/csrc/jit/serialization/export.h>
20#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
21#include <torch/csrc/jit/serialization/import.h>
22
23namespace py = pybind11;
24
25using torch::jit::kFlatbufferDataAlignmentBytes;
26
27static std::shared_ptr<char> copyStr(const std::string& bytes) {
28 size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
29 kFlatbufferDataAlignmentBytes;
30#ifdef _WIN32
31 std::shared_ptr<char> bytes_copy(
32 static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
33 _aligned_free);
34#elif defined(__APPLE__)
35 void* p;
36 ::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
37 TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
38 std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
39#else
40 std::shared_ptr<char> bytes_copy(
41 static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
42 free);
43#endif
44 memcpy(bytes_copy.get(), bytes.data(), bytes.size());
45 return bytes_copy;
46}
47
48extern "C"
49#ifdef _WIN32
50 __declspec(dllexport)
51#endif
52 PyObject* initModuleFlatbuffer() {
53 using namespace torch::jit;
54 PyMethodDef m[] = {{nullptr, nullptr, 0, nullptr}}; // NOLINT
55 static struct PyModuleDef torchmodule = {
56 PyModuleDef_HEAD_INIT,
57 "torch._C_flatbuffer",
58 nullptr,
59 -1,
60 m,
61 }; // NOLINT
62 PyObject* module = PyModule_Create(&torchmodule);
63 auto pym = py::handle(module).cast<py::module>();
64 pym.def("_load_mobile_module_from_file", [](const std::string& filename) {
65 return torch::jit::load_mobile_module_from_file(filename);
66 });
67 pym.def("_load_mobile_module_from_bytes", [](const std::string& bytes) {
68 auto bytes_copy = copyStr(bytes);
69 return torch::jit::parse_and_initialize_mobile_module(
70 bytes_copy, bytes.size());
71 });
72 pym.def("_load_jit_module_from_file", [](const std::string& filename) {
73 ExtraFilesMap extra_files = ExtraFilesMap();
74 return torch::jit::load_jit_module_from_file(filename, extra_files);
75 });
76 pym.def("_load_jit_module_from_bytes", [](const std::string& bytes) {
77 auto bytes_copy = copyStr(bytes);
78 ExtraFilesMap extra_files = ExtraFilesMap();
79 return torch::jit::parse_and_initialize_jit_module(
80 bytes_copy, bytes.size(), extra_files);
81 });
82 pym.def(
83 "_save_mobile_module",
84 [](const torch::jit::mobile::Module& module,
85 const std::string& filename,
86 const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
87 return torch::jit::save_mobile_module(module, filename, _extra_files);
88 });
89 pym.def(
90 "_save_jit_module",
91 [](const torch::jit::Module& module,
92 const std::string& filename,
93 const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
94 return torch::jit::save_jit_module(module, filename, _extra_files);
95 });
96 pym.def(
97 "_save_mobile_module_to_bytes",
98 [](const torch::jit::mobile::Module& module,
99 const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
100 auto detached_buffer =
101 torch::jit::save_mobile_module_to_bytes(module, _extra_files);
102 return py::bytes(
103 reinterpret_cast<char*>(detached_buffer->data()),
104 detached_buffer->size());
105 });
106 pym.def(
107 "_save_jit_module_to_bytes",
108 [](const torch::jit::Module& module,
109 const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
110 auto detached_buffer =
111 torch::jit::save_jit_module_to_bytes(module, _extra_files);
112 return py::bytes(
113 reinterpret_cast<char*>(detached_buffer->data()),
114 detached_buffer->size());
115 });
116 pym.def(
117 "_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
118 py::gil_scoped_acquire acquire;
119 py::dict result;
120 mobile::ModuleInfo minfo = torch::jit::get_module_info_from_flatbuffer(
121 flatbuffer_content.data());
122 result["bytecode_version"] = minfo.bytecode_version;
123 result["operator_version"] = minfo.operator_version;
124 result["function_names"] = minfo.function_names;
125 result["type_names"] = minfo.type_names;
126 result["opname_to_num_args"] = minfo.opname_to_num_args;
127 return result;
128 });
129
130 return module;
131}
132