1 | #pragma once |
2 | #include <ATen/core/jit_type.h> |
3 | #include <torch/csrc/jit/mobile/debug_info.h> |
4 | #include <torch/csrc/jit/mobile/function.h> |
5 | #include <torch/csrc/jit/mobile/method.h> |
6 | #include <torch/csrc/jit/mobile/quantization.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | namespace mobile { |
11 | using Stack = std::vector<c10::IValue>; |
12 | |
13 | // A CompilationUnit object is the one that gets executed by the lite |
14 | // interpreter. |
15 | // |
16 | // A CompilationUnit object contains a list of Method Objects. These are methods |
17 | // that appear in the original PyTorch Model. These method correspond to Python |
18 | // member functions of the Model class. |
19 | // |
20 | // Methods in turn contain a Function, and a back-pointer to the Module that |
21 | // owns this Method instance. |
22 | // |
23 | // A Function contains a Code Object (code_) which is defined in interpreter.h |
24 | // |
25 | // A Code object contains the following: |
26 | // |
27 | // std::vector<Instruction> instructions_; |
28 | // std::vector<c10::OperatorName> op_names_; |
29 | // std::vector<std::function<void(Stack&)>> operators_; |
30 | // std::vector<c10::IValue> constants_; |
31 | // std::vector<c10::TypePtr> types_; |
32 | // size_t register_size_; // Aggregated output size. |
33 | // |
34 | class CompilationUnit { |
35 | public: |
36 | void register_function(std::unique_ptr<Function> fn); |
37 | std::vector<std::unique_ptr<Function>>& methods() { |
38 | return methods_; |
39 | } |
40 | const std::vector<std::unique_ptr<Function>>& methods() const { |
41 | return methods_; |
42 | } |
43 | Function* find_function(const c10::QualifiedName& qn); |
44 | const Function* find_function(const c10::QualifiedName& qn) const; |
45 | |
46 | void unsafeRemoveFunction(const int64_t index) { |
47 | methods_.erase(methods_.begin() + index); |
48 | } |
49 | |
50 | private: |
51 | std::vector<std::unique_ptr<Function>> methods_; |
52 | }; |
53 | |
// A Torch Mobile Module is a representation of a model (in the inference
// case, an already trained model). A Mobile Module contains:
//
// 1. data (object_)
// 2. optional metadata about the model (metadata_, from the metadata.pkl
//    file added after training)
// 3. a compilation unit (cu_)
//
62 | class TORCH_API Module { |
63 | public: |
64 | Module( |
65 | // NOLINTNEXTLINE(modernize-pass-by-value) |
66 | c10::intrusive_ptr<c10::ivalue::Object> object, |
67 | std::shared_ptr<CompilationUnit> cu) |
68 | : object_(object), cu_(std::move(cu)) {} |
69 | Module() = default; |
70 | Method get_method(const std::string& method_name) const; |
71 | template <typename... Types> |
72 | c10::IValue run_method(const std::string& method_name, Types&&... args) { |
73 | return get_method(method_name)({IValue(std::forward<Types>(args))...}); |
74 | } |
75 | c10::IValue forward(std::vector<c10::IValue> inputs) { |
76 | return get_method("forward" )(std::move(inputs)); |
77 | } |
78 | c10::optional<Method> find_method(const std::string& basename) const; |
79 | |
80 | const std::string name() const { |
81 | return object_->name(); |
82 | } |
83 | const std::vector<at::IValue>& slots() const { |
84 | return object_->slots(); |
85 | } |
86 | const c10::intrusive_ptr<c10::ivalue::Object> _ivalue() const { |
87 | return object_; |
88 | } |
89 | const std::vector<at::Tensor> parameters() const; |
90 | const std::map<std::string, at::Tensor> named_parameters() const; |
91 | std::string get_forward_method_debug_info(int64_t debug_handle) const; |
92 | std::string getModuleHierarchy(const int64_t debug_handle) const; |
93 | std::string getCallStack(const int64_t debug_handle) const; |
94 | /// Enables "training" mode. |
95 | void train(bool on = true); |
96 | /// Calls train(false) to enable "eval" mode. |
97 | void eval() { |
98 | train(/*on=*/false); |
99 | } |
100 | /// True if the module is in training mode. |
101 | bool is_training() const; |
102 | const std::unordered_map<std::string, std::string> getMetadata() const { |
103 | return metadata_; |
104 | } |
105 | void setMetadata( |
106 | const std::unordered_map<std::string, std::string>& metadata) { |
107 | metadata_ = metadata; |
108 | } |
109 | const std::vector<Method> get_methods() const; |
110 | |
111 | c10::IValue attr(const std::string& name, c10::IValue or_else) const { |
112 | if (auto r = object_->type()->findAttributeSlot(name)) { |
113 | return object_->getSlot(*r); |
114 | } |
115 | if (auto r = object_->type()->findConstantSlot(name)) { |
116 | return object_->type()->getConstant(*r); |
117 | } |
118 | return or_else; |
119 | } |
120 | |
121 | void setDebugTable(MobileDebugTable&& debug_table) { |
122 | debug_table_ = std::move(debug_table); |
123 | } |
124 | const MobileDebugTable& getDebugTable() const { |
125 | return debug_table_; |
126 | } |
127 | |
128 | void setHasDebugHandles(bool has_debug_handles) { |
129 | has_debug_handles_ = has_debug_handles; |
130 | } |
131 | |
132 | bool hasDebugHandles() const { |
133 | return has_debug_handles_; |
134 | } |
135 | |
136 | const CompilationUnit& compilation_unit() const { |
137 | return *cu_.get(); |
138 | } |
139 | |
140 | void set_delete_memory(std::shared_ptr<char> delete_mem) { |
141 | mem_to_delete_ = delete_mem; |
142 | } |
143 | |
144 | void set_min_operator_version(int64_t version) { |
145 | min_operator_version_ = version; |
146 | } |
147 | |
148 | int64_t min_operator_version() const { |
149 | return min_operator_version_; |
150 | } |
151 | |
152 | void set_bytecode_version(int64_t version) { |
153 | bytecode_version_ = version; |
154 | } |
155 | |
156 | int64_t bytecode_version() const { |
157 | return bytecode_version_; |
158 | } |
159 | |
160 | private: |
161 | friend class quantization::PTQQuanizationHelper; |
162 | |
163 | bool compareMethodSchemas( |
164 | const std::string& name_1, |
165 | const std::string& name_2); |
166 | |
167 | void unsafeRemoveMethod(const std::string& basename); |
168 | |
169 | void unsafeCopyMethod( |
170 | const std::string& new_method_name, |
171 | const Function& to_be_copied); |
172 | |
173 | c10::intrusive_ptr<c10::ivalue::Object> object_; |
174 | std::unordered_map<std::string, std::string> metadata_; |
175 | std::shared_ptr<CompilationUnit> cu_; |
176 | MobileDebugTable debug_table_; |
177 | bool has_debug_handles_ = false; |
178 | int64_t min_operator_version_ = 4; |
179 | int64_t bytecode_version_ = 4; |
180 | |
181 | // Extra handle for the module to delete when itself is deleted |
182 | std::shared_ptr<char> mem_to_delete_; |
183 | }; |
184 | |
185 | struct TORCH_API ModuleInfo { |
186 | uint64_t bytecode_version; |
187 | uint64_t operator_version; |
188 | std::unordered_map<std::string, int> opname_to_num_args; |
189 | std::unordered_set<std::string> function_names; |
190 | std::unordered_set<std::string> type_names; |
191 | }; |
192 | TORCH_API ModuleInfo get_module_info(const mobile::Module& module); |
193 | |
194 | } // namespace mobile |
195 | } // namespace jit |
196 | } // namespace torch |
197 | |