1#pragma once
2#include <ATen/core/jit_type.h>
3#include <torch/csrc/jit/mobile/debug_info.h>
4#include <torch/csrc/jit/mobile/function.h>
5#include <torch/csrc/jit/mobile/method.h>
6#include <torch/csrc/jit/mobile/quantization.h>
7
8namespace torch {
9namespace jit {
10namespace mobile {
11using Stack = std::vector<c10::IValue>;
12
13// A CompilationUnit object is the one that gets executed by the lite
14// interpreter.
15//
// A CompilationUnit object contains a list of Function objects (methods_).
// These correspond to the methods that appear in the original PyTorch model,
// i.e. the Python member functions of the model class.
19//
20// Methods in turn contain a Function, and a back-pointer to the Module that
21// owns this Method instance.
22//
// A Function contains a Code object (code_), which is defined in interpreter.h.
24//
25// A Code object contains the following:
26//
27// std::vector<Instruction> instructions_;
28// std::vector<c10::OperatorName> op_names_;
29// std::vector<std::function<void(Stack&)>> operators_;
30// std::vector<c10::IValue> constants_;
31// std::vector<c10::TypePtr> types_;
32// size_t register_size_; // Aggregated output size.
33//
34class CompilationUnit {
35 public:
36 void register_function(std::unique_ptr<Function> fn);
37 std::vector<std::unique_ptr<Function>>& methods() {
38 return methods_;
39 }
40 const std::vector<std::unique_ptr<Function>>& methods() const {
41 return methods_;
42 }
43 Function* find_function(const c10::QualifiedName& qn);
44 const Function* find_function(const c10::QualifiedName& qn) const;
45
46 void unsafeRemoveFunction(const int64_t index) {
47 methods_.erase(methods_.begin() + index);
48 }
49
50 private:
51 std::vector<std::unique_ptr<Function>> methods_;
52};
53
// A Torch Mobile Module is the in-memory representation of a model (for
// inference, a trained model). A Mobile Module contains
56//
57// 1. data (object_)
58// 2. metadata (optional) about the model (metadata_ from the metadata.pkl
59// file added after training)
60// 3. Compilation Unit (cu_)
61//
62class TORCH_API Module {
63 public:
64 Module(
65 // NOLINTNEXTLINE(modernize-pass-by-value)
66 c10::intrusive_ptr<c10::ivalue::Object> object,
67 std::shared_ptr<CompilationUnit> cu)
68 : object_(object), cu_(std::move(cu)) {}
69 Module() = default;
70 Method get_method(const std::string& method_name) const;
71 template <typename... Types>
72 c10::IValue run_method(const std::string& method_name, Types&&... args) {
73 return get_method(method_name)({IValue(std::forward<Types>(args))...});
74 }
75 c10::IValue forward(std::vector<c10::IValue> inputs) {
76 return get_method("forward")(std::move(inputs));
77 }
78 c10::optional<Method> find_method(const std::string& basename) const;
79
80 const std::string name() const {
81 return object_->name();
82 }
83 const std::vector<at::IValue>& slots() const {
84 return object_->slots();
85 }
86 const c10::intrusive_ptr<c10::ivalue::Object> _ivalue() const {
87 return object_;
88 }
89 const std::vector<at::Tensor> parameters() const;
90 const std::map<std::string, at::Tensor> named_parameters() const;
91 std::string get_forward_method_debug_info(int64_t debug_handle) const;
92 std::string getModuleHierarchy(const int64_t debug_handle) const;
93 std::string getCallStack(const int64_t debug_handle) const;
94 /// Enables "training" mode.
95 void train(bool on = true);
96 /// Calls train(false) to enable "eval" mode.
97 void eval() {
98 train(/*on=*/false);
99 }
100 /// True if the module is in training mode.
101 bool is_training() const;
102 const std::unordered_map<std::string, std::string> getMetadata() const {
103 return metadata_;
104 }
105 void setMetadata(
106 const std::unordered_map<std::string, std::string>& metadata) {
107 metadata_ = metadata;
108 }
109 const std::vector<Method> get_methods() const;
110
111 c10::IValue attr(const std::string& name, c10::IValue or_else) const {
112 if (auto r = object_->type()->findAttributeSlot(name)) {
113 return object_->getSlot(*r);
114 }
115 if (auto r = object_->type()->findConstantSlot(name)) {
116 return object_->type()->getConstant(*r);
117 }
118 return or_else;
119 }
120
121 void setDebugTable(MobileDebugTable&& debug_table) {
122 debug_table_ = std::move(debug_table);
123 }
124 const MobileDebugTable& getDebugTable() const {
125 return debug_table_;
126 }
127
128 void setHasDebugHandles(bool has_debug_handles) {
129 has_debug_handles_ = has_debug_handles;
130 }
131
132 bool hasDebugHandles() const {
133 return has_debug_handles_;
134 }
135
136 const CompilationUnit& compilation_unit() const {
137 return *cu_.get();
138 }
139
140 void set_delete_memory(std::shared_ptr<char> delete_mem) {
141 mem_to_delete_ = delete_mem;
142 }
143
144 void set_min_operator_version(int64_t version) {
145 min_operator_version_ = version;
146 }
147
148 int64_t min_operator_version() const {
149 return min_operator_version_;
150 }
151
152 void set_bytecode_version(int64_t version) {
153 bytecode_version_ = version;
154 }
155
156 int64_t bytecode_version() const {
157 return bytecode_version_;
158 }
159
160 private:
161 friend class quantization::PTQQuanizationHelper;
162
163 bool compareMethodSchemas(
164 const std::string& name_1,
165 const std::string& name_2);
166
167 void unsafeRemoveMethod(const std::string& basename);
168
169 void unsafeCopyMethod(
170 const std::string& new_method_name,
171 const Function& to_be_copied);
172
173 c10::intrusive_ptr<c10::ivalue::Object> object_;
174 std::unordered_map<std::string, std::string> metadata_;
175 std::shared_ptr<CompilationUnit> cu_;
176 MobileDebugTable debug_table_;
177 bool has_debug_handles_ = false;
178 int64_t min_operator_version_ = 4;
179 int64_t bytecode_version_ = 4;
180
181 // Extra handle for the module to delete when itself is deleted
182 std::shared_ptr<char> mem_to_delete_;
183};
184
185struct TORCH_API ModuleInfo {
186 uint64_t bytecode_version;
187 uint64_t operator_version;
188 std::unordered_map<std::string, int> opname_to_num_args;
189 std::unordered_set<std::string> function_names;
190 std::unordered_set<std::string> type_names;
191};
192TORCH_API ModuleInfo get_module_info(const mobile::Module& module);
193
194} // namespace mobile
195} // namespace jit
196} // namespace torch
197