1 | #include <ATen/core/ivalue.h> |
2 | #include <caffe2/serialize/file_adapter.h> |
3 | #include <caffe2/serialize/inline_container.h> |
4 | #include <torch/csrc/jit/api/compilation_unit.h> // removed after using simple type_resolver/obj_loader |
5 | #include <torch/csrc/jit/mobile/compatibility/model_compatibility.h> |
6 | #include <torch/csrc/jit/mobile/file_format.h> |
7 | #include <torch/csrc/jit/mobile/flatbuffer_loader.h> |
8 | #include <torch/csrc/jit/mobile/import.h> // removed after using simple type_resolver/obj_loader |
9 | #include <torch/csrc/jit/mobile/type_parser.h> |
10 | #include <torch/csrc/jit/serialization/import_export_constants.h> |
11 | #include <torch/csrc/jit/serialization/import_read.h> |
12 | |
13 | #include <caffe2/serialize/in_memory_adapter.h> |
14 | #include <sstream> |
15 | #include <string> |
16 | #include <unordered_set> |
17 | #include <vector> |
18 | |
namespace c10 {
// Declared here (defined in the c10 type-parser TU) to avoid pulling in the
// full parser header just for this one entry point.
TypePtr parseType(const std::string& pythonStr);
} // namespace c10
22 | |
23 | namespace torch { |
24 | namespace jit { |
25 | |
26 | using caffe2::serialize::FileAdapter; |
27 | using caffe2::serialize::IStreamAdapter; |
28 | using caffe2::serialize::PyTorchStreamReader; |
29 | using caffe2::serialize::ReadAdapterInterface; |
30 | |
31 | c10::IValue readArchive( |
32 | const std::string& archive_name, |
33 | PyTorchStreamReader& stream_reader) { |
34 | c10::optional<at::Device> device; |
35 | std::shared_ptr<CompilationUnit> compilation_unit = |
36 | std::make_shared<CompilationUnit>(); |
37 | |
38 | // TODO (T90180710): Simplify type_resolver and obj_loader when getting |
39 | // bytecode version from model |
40 | auto type_resolver = [&](const c10::QualifiedName& qn) { |
41 | return typeResolverMobile(qn, compilation_unit); |
42 | }; |
43 | |
44 | std::shared_ptr<mobile::CompilationUnit> mobile_compilation_unit = |
45 | std::make_shared<mobile::CompilationUnit>(); |
46 | auto obj_loader = [&](at::StrongTypePtr type, IValue input) { |
47 | return objLoaderMobile(type, input, *mobile_compilation_unit); |
48 | }; |
49 | bool bytecode_tensor_in_constants_archive = |
50 | (archive_name == "bytecode" && !isTensorInBytecodeArchive(stream_reader)); |
51 | auto ivalues = torch::jit::readArchiveAndTensors( |
52 | archive_name, |
53 | /*pickle_prefix=*/"" , |
54 | /*tensor_prefix=*/ |
55 | bytecode_tensor_in_constants_archive ? "constants/" : "" , |
56 | type_resolver, |
57 | obj_loader, |
58 | device, |
59 | stream_reader, |
60 | nullptr); |
61 | return ivalues; |
62 | } |
63 | |
64 | std::vector<IValue> get_bytecode_ivalues(PyTorchStreamReader& reader) { |
65 | return std::move(*readArchive("bytecode" , reader).toTuple()).elements().vec(); |
66 | } |
67 | |
68 | /********************** Bytecode **********************/ |
69 | |
// Forward declare: the stream/file/adapter overloads below funnel into these
// two terminal implementations defined further down in this file.
uint64_t _get_model_bytecode_version(
    const std::vector<IValue>& bytecode_ivalues);
static uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size);
74 | |
75 | uint64_t _get_model_bytecode_version(std::istream& in) { |
76 | auto orig_pos = in.tellg(); |
77 | in.seekg(0, in.beg); |
78 | std::shared_ptr<char> data; |
79 | size_t size = 0; |
80 | std::tie(data, size) = get_stream_content(in); |
81 | in.seekg(orig_pos, in.beg); |
82 | return _get_model_bytecode_version_from_bytes(data.get(), size); |
83 | } |
84 | |
85 | uint64_t _get_model_bytecode_version(const std::string& filename) { |
86 | std::ifstream ifile(filename); |
87 | return _get_model_bytecode_version(ifile); |
88 | } |
89 | |
90 | uint64_t _get_model_bytecode_version( |
91 | std::shared_ptr<ReadAdapterInterface> rai) { |
92 | std::shared_ptr<char> data; |
93 | size_t size = 0; |
94 | std::tie(data, size) = get_rai_content(rai.get()); |
95 | return _get_model_bytecode_version_from_bytes(data.get(), size); |
96 | } |
97 | |
98 | uint64_t _get_model_bytecode_version_zip( |
99 | std::shared_ptr<ReadAdapterInterface> rai) { |
100 | if (!check_zip_file(rai)) { |
101 | TORCH_CHECK( |
102 | false, |
103 | "Failed to open .ptl file please ensure the model was exported for mobile" ); |
104 | } |
105 | PyTorchStreamReader reader(std::move(rai)); |
106 | auto bytecode_values = get_bytecode_ivalues(reader); |
107 | return _get_model_bytecode_version(bytecode_values); |
108 | } |
109 | |
110 | uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size) { |
111 | TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format" ); |
112 | auto format = getFileFormat(data); |
113 | switch (format) { |
114 | case FileFormat::FlatbufferFileFormat: { |
115 | return get_bytecode_version_from_bytes(data); |
116 | } |
117 | case FileFormat::ZipFileFormat: { |
118 | auto rai = |
119 | std::make_unique<caffe2::serialize::MemoryReadAdapter>(data, size); |
120 | auto version = _get_model_bytecode_version_zip(std::move(rai)); |
121 | return version; |
122 | } |
123 | |
124 | default: |
125 | TORCH_CHECK(false, "Unrecognized data format" ); |
126 | } |
127 | } |
128 | |
129 | uint64_t _get_model_bytecode_version( |
130 | const std::vector<IValue>& bytecode_ivalues) { |
131 | if (!bytecode_ivalues.empty() && bytecode_ivalues[0].isInt()) { |
132 | int64_t model_version = bytecode_ivalues[0].toInt(); |
133 | TORCH_CHECK( |
134 | model_version > 0, |
135 | "Expected model bytecode version > 0 got " , |
136 | model_version); |
137 | return static_cast<uint64_t>(model_version); |
138 | } |
139 | TORCH_CHECK(false, "Failed to get bytecode version." ); |
140 | } |
141 | |
142 | /********************** Operator Version **********************/ |
143 | |
// Forward declare: the reader-based overload below is the terminal
// implementation the stream/file/adapter overloads delegate to.
uint64_t _get_model_operator_version(
    PyTorchStreamReader& reader); // Forward Declare
146 | |
147 | uint64_t _get_model_operator_version(std::istream& in) { |
148 | std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in); |
149 | return _get_model_operator_version(std::move(rai)); |
150 | } |
151 | |
152 | uint64_t _get_model_operator_version(const std::string& filename) { |
153 | std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename); |
154 | return _get_model_operator_version(std::move(rai)); |
155 | } |
156 | |
157 | uint64_t _get_model_operator_version( |
158 | std::shared_ptr<ReadAdapterInterface> rai) { |
159 | if (!check_zip_file(rai)) { |
160 | TORCH_CHECK( |
161 | false, |
162 | "Failed to open .ptl file please ensure the model was exported for mobile" ); |
163 | } |
164 | PyTorchStreamReader reader(std::move(rai)); |
165 | return _get_model_operator_version(reader); |
166 | } |
167 | |
// The operator version is the version record stored in the zip container
// itself (PyTorchStreamReader::version()), not part of the bytecode tuple.
uint64_t _get_model_operator_version(PyTorchStreamReader& reader) {
  return reader.version();
}
171 | |
172 | /********************** Operators and Info **********************/ |
173 | |
// Forward declare: the ivalue-based implementation (with the full doc
// comment) is defined below; the overloads in between delegate to it.
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    std::vector<IValue> bytecode_ivalues);
177 | |
178 | std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info( |
179 | std::istream& in) { |
180 | std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in); |
181 | return _get_model_ops_and_info(std::move(rai)); |
182 | } |
183 | |
184 | std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info( |
185 | const std::string& filename) { |
186 | std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename); |
187 | return _get_model_ops_and_info(std::move(rai)); |
188 | } |
189 | |
190 | std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info( |
191 | std::shared_ptr<ReadAdapterInterface> rai) { |
192 | if (!check_zip_file(rai)) { |
193 | TORCH_WARN("Failed to open zip file for model ops." ); |
194 | return std::unordered_map<std::string, OperatorInfo>{}; |
195 | } |
196 | PyTorchStreamReader reader(std::move(rai)); |
197 | auto bytecode_values = get_bytecode_ivalues(reader); |
198 | return _get_model_ops_and_info(bytecode_values); |
199 | } |
200 | |
201 | /* A function to retrieve the root (top level) operators of a model and their |
202 | * corresponding compatibility info. These root operators can call other |
203 | * operators within them (traced ops), and a root op can call many different |
204 | * traced ops depending on internal code paths in the root op. These traced ops |
205 | * are not returned by this function. Those operators are abstracted into the |
206 | * runtime as an implementation detail (and the traced ops themselves can also |
207 | * call other operators) making retrieving them difficult and their value from |
208 | * this api negligible since they will differ between which runtime version the |
209 | * model is run on. Because of this, there is a false positive this api can't |
210 | * prevent in a compatibility usecase. All the root ops of a model are present |
211 | * in a target runtime, but not all the traced ops are which prevents a model |
212 | * from being able to run. |
213 | **/ |
214 | std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info( |
215 | std::vector<IValue> bytecode_ivalues) { |
216 | constexpr uint64_t min_version_with_schema = 6; |
217 | if (_get_model_bytecode_version(bytecode_ivalues) < min_version_with_schema) { |
218 | TORCH_WARN( |
219 | "Only models with bytecode version 6 and above contain operator schema information. Please re-export your model to generate it" ); |
220 | } |
221 | std::unordered_map<std::string, OperatorInfo> result; |
222 | if (bytecode_ivalues.empty()) { |
223 | TORCH_WARN("Failed to get model ops and info." ); |
224 | return result; |
225 | } |
226 | // loop over all the functions in the bytecode |
227 | for (const auto i : c10::irange(1, bytecode_ivalues.size())) { |
228 | // descend to the operators list |
229 | const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements(); |
230 | auto operators_tuple = method_tuple.at(1).toTupleRef().elements()[1]; |
231 | auto operators = operators_tuple.toTupleRef().elements()[1]; |
232 | for (auto& op_tuple : operators.toTupleRef().elements()) { |
233 | const auto& op = op_tuple.toTupleRef().elements(); |
234 | |
235 | // grab name |
236 | std::string op_name = op.at(0).toStringRef(); |
237 | std::string op_overload_name = op.at(1).toStringRef(); |
238 | if (!op_overload_name.empty()) { |
239 | op_name.append("." ); |
240 | op_name.append(op_overload_name); |
241 | } |
242 | |
243 | // grab schema size |
244 | if (op.size() > 2) { |
245 | result.emplace(op_name, OperatorInfo{(int)op.at(2).toInt()}); |
246 | } else { // no schema information use default |
247 | result.emplace(op_name, OperatorInfo{}); |
248 | } |
249 | } |
250 | } |
251 | return result; |
252 | } |
253 | |
254 | /********************** Get Type Table **********************/ |
255 | |
// Forward declare: the ivalue-based implementation is defined below; the
// stream/file/adapter overloads in between delegate to it.
std::unordered_set<std::string> _get_mobile_model_contained_types(
    const std::vector<IValue>& bytecode_ivalues);
259 | |
260 | std::unordered_set<std::string> _get_mobile_model_contained_types( |
261 | std::istream& in) { |
262 | std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in); |
263 | return _get_mobile_model_contained_types(std::move(rai)); |
264 | } |
265 | |
266 | std::unordered_set<std::string> _get_mobile_model_contained_types( |
267 | const std::string& filename) { |
268 | std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename); |
269 | return _get_mobile_model_contained_types(std::move(rai)); |
270 | } |
271 | |
272 | std::unordered_set<std::string> _get_mobile_model_contained_types( |
273 | std::shared_ptr<ReadAdapterInterface> rai) { |
274 | if (!check_zip_file(rai)) { |
275 | TORCH_CHECK( |
276 | false, |
277 | "Failed to open .ptl file please ensure the model was exported for mobile" ); |
278 | } |
279 | PyTorchStreamReader reader(std::move(rai)); |
280 | auto bytecode_values = get_bytecode_ivalues(reader); |
281 | return _get_mobile_model_contained_types(bytecode_values); |
282 | } |
283 | |
284 | // Get deduplicate type table given bytecode, and each string is a atomic type, |
285 | // like str, Tensor and etc. For example, |
286 | // input: "Dict[int, Tuple[Tensor, Tensor, Tensor]]" |
287 | // output: {Dict, int, Tuple, Tensor} |
288 | std::unordered_set<std::string> _get_mobile_model_contained_types( |
289 | const std::vector<IValue>& bytecode_ivalues) { |
290 | std::unordered_set<std::string> contained_types; |
291 | // To avoid parsing same type twice, declare $parsed_type_names_records and |
292 | // use type name (string, ex: "Dict[int, Tuple[Tensor, Tensor, Tensor]]") as |
293 | // the hash to record which types are parsed. |
294 | std::unordered_set<std::string> parsed_type_names_records; |
295 | for (const auto i : c10::irange(1, bytecode_ivalues.size())) { |
296 | const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements(); |
297 | auto type_table_tuple = |
298 | method_tuple.at(1).toTupleRef().elements()[BYTECODE_INDEX_TYPE]; |
299 | const auto& type_table = |
300 | type_table_tuple.toTupleRef().elements()[1].toTupleRef().elements(); |
301 | |
302 | // type_table is a list of IValue, and each IValue is a string, |
303 | // for example: "Dict[int, Tuple[Tensor, Tensor, Tensor]]" |
304 | std::vector<std::string> type_name_list; |
305 | for (const auto& type_definition : type_table) { |
306 | std::unordered_set<std::string> type_tokens; |
307 | std::string type_name = type_definition.toStringRef(); |
308 | type_name_list.emplace_back(type_name); |
309 | } |
310 | at::TypeParser parser(type_name_list); |
311 | parser.parseList(); |
312 | contained_types = parser.getContainedTypes(); |
313 | } |
314 | |
315 | return contained_types; |
316 | } |
317 | |
318 | /********************** Compatibility Checker **********************/ |
319 | |
320 | ModelCompatibilityInfo ModelCompatibilityInfo::get(std::istream& in) { |
321 | std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in); |
322 | return get(std::move(rai)); |
323 | } |
324 | |
325 | ModelCompatibilityInfo ModelCompatibilityInfo::get( |
326 | const std::string& filename) { |
327 | std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename); |
328 | return get(std::move(rai)); |
329 | } |
330 | |
331 | ModelCompatibilityInfo ModelCompatibilityInfo::get( |
332 | std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai) { |
333 | if (!check_zip_file(rai)) { |
334 | TORCH_CHECK( |
335 | false, "Failed to open zip file for model compatibility information" ); |
336 | } |
337 | PyTorchStreamReader reader(std::move(rai)); |
338 | std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader); |
339 | uint64_t model_bytecode_version = |
340 | _get_model_bytecode_version(bytecode_values); |
341 | auto model_info = _get_model_ops_and_info(bytecode_values); |
342 | std::unordered_set<std::string> type_table = |
343 | _get_mobile_model_contained_types(bytecode_values); |
344 | uint64_t operator_version = _get_model_operator_version(reader); |
345 | return ModelCompatibilityInfo{ |
346 | model_bytecode_version, model_info, type_table, operator_version}; |
347 | } |
348 | |
349 | ModelCompatCheckResult is_compatible( |
350 | RuntimeCompatibilityInfo runtime_info, |
351 | ModelCompatibilityInfo model_info) { |
352 | ModelCompatCheckResult result = {ModelCompatibilityStatus::OK, {}}; |
353 | // Check that the models bytecode version is less than or equal to |
354 | // kMaxSupportedBytecodeVersion from the runtime |
355 | if (model_info.bytecode_version > |
356 | runtime_info.min_max_supported_bytecode_version.second) { |
357 | result.status = ModelCompatibilityStatus::ERROR; |
358 | std::ostringstream s; |
359 | s << "model bytecode version " << model_info.bytecode_version |
360 | << "is greater than the max supported bytecode version in runtimes " |
361 | << runtime_info.min_max_supported_bytecode_version.second; |
362 | result.errors.emplace_back(s.str()); |
363 | } else if ( |
364 | model_info.bytecode_version < |
365 | runtime_info.min_max_supported_bytecode_version.first) { |
366 | result.status = ModelCompatibilityStatus::ERROR; |
367 | std::ostringstream s; |
368 | s << "model bytecode version " << model_info.bytecode_version |
369 | << "is less than the minimum supported bytecode version in runtime " |
370 | << runtime_info.min_max_supported_bytecode_version.first; |
371 | result.errors.emplace_back(s.str()); |
372 | } |
373 | |
374 | std::unordered_set<std::string> supported_type = runtime_info.supported_types; |
375 | |
376 | // Check type table |
377 | for (const auto& type_name : model_info.type_table) { |
378 | if (supported_type.find(type_name) == supported_type.end()) { |
379 | result.status = ModelCompatibilityStatus::ERROR; |
380 | std::ostringstream s; |
381 | s << "Primitive type: '" << type_name |
382 | << "' is not supported in current runtime" ; |
383 | result.errors.push_back(s.str()); |
384 | } |
385 | } |
386 | |
387 | // Check operators |
388 | std::unordered_map<std::string, OperatorInfo> operator_info = |
389 | model_info.operator_info; |
390 | for (auto const& op : operator_info) { |
391 | std::string op_name = op.first; |
392 | OperatorInfo model_op_info = op.second; |
393 | |
394 | // Check if operator not present in runtime |
395 | if (runtime_info.operator_info.find(op_name) == |
396 | runtime_info.operator_info.end()) { |
397 | result.status = ModelCompatibilityStatus::ERROR; |
398 | std::ostringstream s; |
399 | s << "Operator '" << op_name << "' missing from runtime (not found)" ; |
400 | result.errors.push_back(s.str()); |
401 | } else { |
402 | OperatorInfo runtime_op_info = runtime_info.operator_info.at(op_name); |
403 | |
404 | // If the runtime op has no schema information its a false alarm and isn't |
405 | // actually useable |
406 | if (!runtime_op_info.num_schema_args.has_value()) { |
407 | result.status = ModelCompatibilityStatus::ERROR; |
408 | std::ostringstream s; |
409 | s << "Operator '" << op_name |
410 | << "' missing from runtime (missing schema)" ; |
411 | result.errors.push_back(s.str()); |
412 | } else { |
413 | // Check if the model operator has schema information. If it doesn't |
414 | // then the model is from a bytecode version < 6 and we are done. If the |
415 | // model has more args than the runtime, then the runtime can't know |
416 | // what to do so we aren't compatible. If the runtime has more args than |
417 | // the model then we can just use default values and be fine. |
418 | if (model_op_info.num_schema_args.has_value() && |
419 | (model_op_info.num_schema_args.value() > |
420 | runtime_op_info.num_schema_args.value())) { |
421 | result.status = ModelCompatibilityStatus::ERROR; |
422 | std::ostringstream s; |
423 | s << "Operator schema for'" << op_name << "' has " |
424 | << model_op_info.num_schema_args.value() |
425 | << " args in model but only " |
426 | << runtime_op_info.num_schema_args.value() << " in the runtime" ; |
427 | result.errors.push_back(s.str()); |
428 | } |
429 | } |
430 | } |
431 | } |
432 | |
433 | // Check Operator Versions |
434 | if (model_info.operator_version < |
435 | runtime_info.min_max_supported_opperator_versions.first || |
436 | model_info.operator_version > |
437 | runtime_info.min_max_supported_opperator_versions.second) { |
438 | result.status = ModelCompatibilityStatus::ERROR; |
439 | std::ostringstream s; |
440 | s << "Model Operator Version " << model_info.operator_version |
441 | << "is not within supported version range of the runtime " |
442 | << runtime_info.min_max_supported_opperator_versions.first << " to " |
443 | << runtime_info.min_max_supported_opperator_versions.second; |
444 | result.errors.push_back(s.str()); |
445 | } |
446 | |
447 | return result; |
448 | } |
449 | |
450 | } // namespace jit |
451 | } // namespace torch |
452 | |