1#pragma once
2
3#include <ATen/core/ivalue.h>
4#include <caffe2/serialize/inline_container.h>
5#include <torch/csrc/jit/api/module.h>
6#include <torch/csrc/jit/ir/ir.h>
7#include <torch/csrc/jit/serialization/unpickler.h>
8
9#include <istream>
10
// Forward declaration of the custom-input adapter type that the
// `import_ir_module` / `load` overloads in this header accept by smart
// pointer; those signatures only need it as an incomplete type.
namespace caffe2 { namespace serialize {
class ReadAdapterInterface;
}} // namespace caffe2::serialize
16
17namespace torch {
18namespace jit {
19
20TORCH_API Module import_ir_module(
21 std::shared_ptr<CompilationUnit> cu,
22 const std::string& filename,
23 c10::optional<c10::Device> device = c10::nullopt,
24 bool load_debug_files = true);
25
26TORCH_API Module import_ir_module(
27 std::shared_ptr<CompilationUnit> cu,
28 std::istream& in,
29 c10::optional<c10::Device> device = c10::nullopt,
30 bool load_debug_files = true);
31
32TORCH_API Module import_ir_module(
33 std::shared_ptr<CompilationUnit> cu,
34 std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
35 c10::optional<c10::Device> device = c10::nullopt,
36 bool load_debug_files = true);
37
38TORCH_API Module import_ir_module(
39 std::shared_ptr<CompilationUnit> cu,
40 const std::string& filename,
41 c10::optional<c10::Device> device,
42 ExtraFilesMap& extra_files,
43 bool load_debug_files = true,
44 bool restore_shapes = false);
45
46// For reading unified serialization format from torch.Package
47TORCH_API Module import_ir_module(
48 std::shared_ptr<CompilationUnit> cu,
49 std::shared_ptr<caffe2::serialize::PyTorchStreamReader> reader,
50 std::shared_ptr<torch::jit::DeserializationStorageContext> storage_context,
51 c10::optional<at::Device> device,
52 std::string ts_id /* torchscript identifier inside package */);
53
54TORCH_API Module import_ir_module(
55 std::shared_ptr<CompilationUnit> cu,
56 std::istream& in,
57 c10::optional<c10::Device> device,
58 ExtraFilesMap& extra_files,
59 bool load_debug_files = true,
60 bool restore_shapes = false);
61
62TORCH_API Module import_ir_module(
63 std::shared_ptr<CompilationUnit> cu,
64 std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
65 c10::optional<c10::Device> device,
66 ExtraFilesMap& extra_files,
67 bool load_debug_files = true);
68
69TORCH_API Module import_ir_module(
70 std::shared_ptr<CompilationUnit> cu,
71 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
72 c10::optional<c10::Device> device,
73 ExtraFilesMap& extra_files,
74 bool load_debug_files = true);
75
76/// Loads a serialized `Module` from the given `istream`.
77///
78/// The istream must contain a serialized `Module`, exported via
79/// `torch::jit::ExportModule` in C++.
80TORCH_API Module load(
81 std::istream& in,
82 c10::optional<c10::Device> device = c10::nullopt,
83 bool load_debug_files = true);
84
85TORCH_API Module load(
86 std::istream& in,
87 c10::optional<c10::Device> device,
88 ExtraFilesMap& extra_files,
89 bool load_debug_files = true);
90
91/// Loads a serialized `Module` from the given `filename`.
92///
93/// The file stored at the location given in `filename` must contain a
94/// serialized `Module`, exported either via `ScriptModule.save()` in
95/// Python or `torch::jit::ExportModule` in C++.
96TORCH_API Module load(
97 const std::string& filename,
98 c10::optional<c10::Device> device = c10::nullopt,
99 bool load_debug_files = true);
100
101TORCH_API Module load(
102 const std::string& filename,
103 c10::optional<c10::Device> device,
104 ExtraFilesMap& extra_files,
105 bool load_debug_files = true);
106
107/// Loads a serialized `Module` from the given shared_ptr `rai`.
108///
109/// The reader adapter, which is for customized input stream, must contain a
110/// serialized `Module`, exported either via `ScriptModule.save()` in
111/// Python or `torch::jit::ExportModule` in C++.
112TORCH_API Module load(
113 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
114 c10::optional<c10::Device> device = c10::nullopt,
115 bool load_debug_files = true);
116
117TORCH_API Module load(
118 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai,
119 c10::optional<c10::Device> device,
120 ExtraFilesMap& extra_files,
121 bool load_debug_files = true);
122
123TORCH_API Module jitModuleFromSourceAndConstants(
124 const IValue& ivalue,
125 const ExtraFilesMap& source,
126 const std::vector<IValue>& constants,
127 int32_t version);
128
129TORCH_API Module parse_and_initialize_jit_module(
130 std::shared_ptr<char> data,
131 size_t size,
132 ExtraFilesMap& extra_files,
133 c10::optional<at::Device> device = c10::nullopt);
134
135TORCH_API Module load_jit_module_from_file(
136 const std::string& filename,
137 ExtraFilesMap& extra_files,
138 c10::optional<at::Device> device = c10::nullopt);
139
140TORCH_API Module load_jit_module_from_stream(
141 std::istream& in,
142 ExtraFilesMap& extra_files,
143 c10::optional<at::Device> device = c10::nullopt);
144
145TORCH_API Module parse_and_initialize_jit_module(
146 std::shared_ptr<char> data,
147 size_t size,
148 ExtraFilesMap& extra_files,
149 c10::optional<at::Device> device);
150
151} // namespace jit
152} // namespace torch
153