1#pragma once
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_operators.h>

#include <istream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>

#include <caffe2/serialize/file_adapter.h>
9
10namespace torch {
11namespace jit {
12using caffe2::serialize::FileAdapter;
13using caffe2::serialize::IStreamAdapter;
14using caffe2::serialize::ReadAdapterInterface;
15using ExtraFilesMap = std::unordered_map<std::string, std::string>;
16
// Names of the archives read from a serialized mobile model
// (e.g. the `bytecode` archive inspected by isTensorInBytecodeArchive).
constexpr char const* kArchiveNameBytecode{"bytecode"};
constexpr char const* kArchiveNameConstants{"constants"};
constexpr char const* kArchiveNameVersion{"version"};
20
21// The family of methods below load a serialized Mobile Module
22// into a mobile::Module object.
23TORCH_API mobile::Module _load_for_mobile(
24 std::istream& in,
25 c10::optional<at::Device> device,
26 ExtraFilesMap& extra_files);
27
28TORCH_API mobile::Module _load_for_mobile(
29 const std::string& filename,
30 c10::optional<at::Device> device,
31 ExtraFilesMap& extra_files);
32
33TORCH_API mobile::Module _load_for_mobile(
34 std::unique_ptr<ReadAdapterInterface> rai,
35 c10::optional<c10::Device> device,
36 ExtraFilesMap& extra_files,
37 uint64_t module_load_options = kDefaultMobileLoadOptions);
38
39TORCH_API mobile::Module _load_for_mobile(
40 const std::string& filename,
41 c10::optional<at::Device> device,
42 ExtraFilesMap& extra_files,
43 uint64_t module_load_options);
44
45TORCH_API mobile::Module _load_for_mobile(
46 std::istream& in,
47 c10::optional<at::Device> device = c10::nullopt);
48
49TORCH_API mobile::Module _load_for_mobile(
50 const std::string& filename,
51 c10::optional<at::Device> device = c10::nullopt);
52
53TORCH_API mobile::Module _load_for_mobile(
54 std::unique_ptr<ReadAdapterInterface> rai,
55 c10::optional<c10::Device> device = c10::nullopt);
56
57/**
58 * Load only the contents of the "extra/" files whose names are
59 * passed in the map (extra_files). Populate the corresponding values
60 * with the contents of those files. Do not attempt to load the entire
61 * model, and stop once the extra files have been extracted.
62 *
63 * This API is needed to be able to load GPU models on linux CPU
64 * machines and extract only the extra files so that we can inspect
65 * the metadata that was added to the .ptl archive when it was
66 * generated.
67 *
68 */
69void _load_extra_only_for_mobile(
70 const std::string& filename,
71 c10::optional<at::Device> device,
72 ExtraFilesMap& extra_files);
73
74// Currently used by both mobile/import.cpp and model_compatibility.cpp.
75// Should be removed after model_compatibility.cpp start using simplified
76// version type_resolver and obj_loader.
77at::TypePtr resolveTypeNameMobile(
78 const c10::QualifiedName& qn,
79 std::shared_ptr<CompilationUnit> compilation_unit);
80c10::StrongTypePtr typeResolverMobile(
81 const c10::QualifiedName& qn,
82 std::shared_ptr<CompilationUnit> compilation_unit);
83c10::intrusive_ptr<c10::ivalue::Object> objLoaderMobile(
84 const at::StrongTypePtr& type,
85 const at::IValue& input,
86 mobile::CompilationUnit& mobile_compilation_unit);
87
88// Given a reader, which has access to a model file,
89// return true if there exists tensors in `bytecode` archive
90bool isTensorInBytecodeArchive(
91 caffe2::serialize::PyTorchStreamReader& stream_reader);
92
93namespace mobile {
94
95/**
96 * Given a torch::jit::mobile::Module, return a set of operator names
97 * (with overload name) that are used by any method in this mobile
98 * Mobile. This method runs through the bytecode for all methods
99 * in the specified model (module), and extracts all the root
100 * operator names. Root operators are operators that are called
101 * directly by the model (as opposed to non-root operators, which
102 * may be called transitively by the root operators).
103 *
104 */
105TORCH_API std::set<std::string> _export_operator_list(
106 torch::jit::mobile::Module& module);
107
108} // namespace mobile
109
110} // namespace jit
111} // namespace torch
112