#include <torch/csrc/python_headers.h>

#include <libshm.h>
#include <cstdlib>
#include <cstring> // memcpy

#include <pybind11/detail/common.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/csrc/utils/pybind.h>

#include <Python.h> // NOLINT
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/python_ivalue.h>
#include <torch/csrc/jit/python/python_sugared_value.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/import.h>
namespace py = pybind11;

using torch::jit::kFlatbufferDataAlignmentBytes;

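// Copy `bytes` into a freshly allocated buffer whose size is rounded up to
// the next multiple of kFlatbufferDataAlignmentBytes, using the platform's
// aligned allocator. The flatbuffer parser needs an aligned buffer, and the
// storage backing a Python `bytes` object carries no alignment guarantee.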
static std::shared_ptr<char> copyStr(const std::string& bytes) {
  // Round the allocation up to the next multiple of the required alignment.
  size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
      kFlatbufferDataAlignmentBytes;
#ifdef _WIN32
  std::shared_ptr<char> bytes_copy(
      static_cast<char*>(_aligned_malloc(size, kFlatbufferDataAlignmentBytes)),
      _aligned_free);
#elif defined(__APPLE__)
  void* p;
  ::posix_memalign(&p, kFlatbufferDataAlignmentBytes, size);
  TORCH_INTERNAL_ASSERT(p, "Could not allocate memory for flatbuffer");
  std::shared_ptr<char> bytes_copy(static_cast<char*>(p), free);
#else
  std::shared_ptr<char> bytes_copy(
      static_cast<char*>(aligned_alloc(kFlatbufferDataAlignmentBytes, size)),
      free);
#endif
  memcpy(bytes_copy.get(), bytes.data(), bytes.size());
  return bytes_copy;
}

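// Entry point for the `torch._C_flatbuffer` extension module. Creates the
// CPython module object and registers the flatbuffer load/save bindings on
// it via pybind11.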
extern "C"
#ifdef _WIN32
__declspec(dllexport)
#endif
PyObject* initModuleFlatbuffer() {
  using namespace torch::jit;
  PyMethodDef m[] = {{nullptr, nullptr, 0, nullptr}}; // NOLINT
  static struct PyModuleDef torchmodule = {
      PyModuleDef_HEAD_INIT,
      "torch._C_flatbuffer",
      nullptr,
      -1,
      m,
  }; // NOLINT
  PyObject* module = PyModule_Create(&torchmodule);
  auto pym = py::handle(module).cast<py::module>();
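  // Loaders: construct a mobile::Module or jit::Module from a flatbuffer
  // file on disk, or from an in-memory buffer (copied into aligned storage
  // by copyStr first).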
  pym.def("_load_mobile_module_from_file", [](const std::string& filename) {
    return torch::jit::load_mobile_module_from_file(filename);
  });
  pym.def("_load_mobile_module_from_bytes", [](const std::string& bytes) {
    auto bytes_copy = copyStr(bytes);
    return torch::jit::parse_and_initialize_mobile_module(
        bytes_copy, bytes.size());
  });
  pym.def("_load_jit_module_from_file", [](const std::string& filename) {
    ExtraFilesMap extra_files;
    return torch::jit::load_jit_module_from_file(filename, extra_files);
  });
  pym.def("_load_jit_module_from_bytes", [](const std::string& bytes) {
    auto bytes_copy = copyStr(bytes);
    ExtraFilesMap extra_files;
    return torch::jit::parse_and_initialize_jit_module(
        bytes_copy, bytes.size(), extra_files);
  });
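  // Savers: serialize a module to a flatbuffer file on disk, or return the
  // serialized form to Python as a `bytes` object.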
  pym.def(
      "_save_mobile_module",
      [](const torch::jit::mobile::Module& module,
         const std::string& filename,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        return torch::jit::save_mobile_module(module, filename, _extra_files);
      });
  pym.def(
      "_save_jit_module",
      [](const torch::jit::Module& module,
         const std::string& filename,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        return torch::jit::save_jit_module(module, filename, _extra_files);
      });
  pym.def(
      "_save_mobile_module_to_bytes",
      [](const torch::jit::mobile::Module& module,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        auto detached_buffer =
            torch::jit::save_mobile_module_to_bytes(module, _extra_files);
        return py::bytes(
            reinterpret_cast<char*>(detached_buffer->data()),
            detached_buffer->size());
      });
  pym.def(
      "_save_jit_module_to_bytes",
      [](const torch::jit::Module& module,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        auto detached_buffer =
            torch::jit::save_jit_module_to_bytes(module, _extra_files);
        return py::bytes(
            reinterpret_cast<char*>(detached_buffer->data()),
            detached_buffer->size());
      });
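  // Introspection: expose flatbuffer metadata (bytecode/operator versions,
  // function and type names, per-operator argument counts) to Python as a
  // dict.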
  pym.def(
      "_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
        py::gil_scoped_acquire acquire;
        py::dict result;
        mobile::ModuleInfo minfo = torch::jit::get_module_info_from_flatbuffer(
            flatbuffer_content.data());
        result["bytecode_version"] = minfo.bytecode_version;
        result["operator_version"] = minfo.operator_version;
        result["function_names"] = minfo.function_names;
        result["type_names"] = minfo.type_names;
        result["opname_to_num_args"] = minfo.opname_to_num_args;
        return result;
      });
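  // Example round-trip from Python (a sketch, not part of this file; assumes
  // the extension is importable as torch._C_flatbuffer and `m` is a C++
  // module handle such as `scripted._c`):
  //   from torch import _C_flatbuffer as ff
  //   ff._save_jit_module(m, "model.ff", {})
  //   loaded = ff._load_jit_module_from_file("model.ff")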

  return module;
}