#include <ATen/core/ivalue.h>
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/in_memory_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/mobile/import.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_read.h>

#include <cstdint>
#include <fstream>
#include <istream>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
18
19namespace c10 {
20TypePtr parseType(const std::string& pythonStr);
21} // namespace c10
22
23namespace torch {
24namespace jit {
25
26using caffe2::serialize::FileAdapter;
27using caffe2::serialize::IStreamAdapter;
28using caffe2::serialize::PyTorchStreamReader;
29using caffe2::serialize::ReadAdapterInterface;
30
31c10::IValue readArchive(
32 const std::string& archive_name,
33 PyTorchStreamReader& stream_reader) {
34 c10::optional<at::Device> device;
35 std::shared_ptr<CompilationUnit> compilation_unit =
36 std::make_shared<CompilationUnit>();
37
38 // TODO (T90180710): Simplify type_resolver and obj_loader when getting
39 // bytecode version from model
40 auto type_resolver = [&](const c10::QualifiedName& qn) {
41 return typeResolverMobile(qn, compilation_unit);
42 };
43
44 std::shared_ptr<mobile::CompilationUnit> mobile_compilation_unit =
45 std::make_shared<mobile::CompilationUnit>();
46 auto obj_loader = [&](at::StrongTypePtr type, IValue input) {
47 return objLoaderMobile(type, input, *mobile_compilation_unit);
48 };
49 bool bytecode_tensor_in_constants_archive =
50 (archive_name == "bytecode" && !isTensorInBytecodeArchive(stream_reader));
51 auto ivalues = torch::jit::readArchiveAndTensors(
52 archive_name,
53 /*pickle_prefix=*/"",
54 /*tensor_prefix=*/
55 bytecode_tensor_in_constants_archive ? "constants/" : "",
56 type_resolver,
57 obj_loader,
58 device,
59 stream_reader,
60 nullptr);
61 return ivalues;
62}
63
64std::vector<IValue> get_bytecode_ivalues(PyTorchStreamReader& reader) {
65 return std::move(*readArchive("bytecode", reader).toTuple()).elements().vec();
66}
67
68/********************** Bytecode **********************/
69
70// Forward declare
71uint64_t _get_model_bytecode_version(
72 const std::vector<IValue>& bytecode_ivalues);
73static uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size);
74
75uint64_t _get_model_bytecode_version(std::istream& in) {
76 auto orig_pos = in.tellg();
77 in.seekg(0, in.beg);
78 std::shared_ptr<char> data;
79 size_t size = 0;
80 std::tie(data, size) = get_stream_content(in);
81 in.seekg(orig_pos, in.beg);
82 return _get_model_bytecode_version_from_bytes(data.get(), size);
83}
84
85uint64_t _get_model_bytecode_version(const std::string& filename) {
86 std::ifstream ifile(filename);
87 return _get_model_bytecode_version(ifile);
88}
89
90uint64_t _get_model_bytecode_version(
91 std::shared_ptr<ReadAdapterInterface> rai) {
92 std::shared_ptr<char> data;
93 size_t size = 0;
94 std::tie(data, size) = get_rai_content(rai.get());
95 return _get_model_bytecode_version_from_bytes(data.get(), size);
96}
97
98uint64_t _get_model_bytecode_version_zip(
99 std::shared_ptr<ReadAdapterInterface> rai) {
100 if (!check_zip_file(rai)) {
101 TORCH_CHECK(
102 false,
103 "Failed to open .ptl file please ensure the model was exported for mobile");
104 }
105 PyTorchStreamReader reader(std::move(rai));
106 auto bytecode_values = get_bytecode_ivalues(reader);
107 return _get_model_bytecode_version(bytecode_values);
108}
109
110uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size) {
111 TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");
112 auto format = getFileFormat(data);
113 switch (format) {
114 case FileFormat::FlatbufferFileFormat: {
115 return get_bytecode_version_from_bytes(data);
116 }
117 case FileFormat::ZipFileFormat: {
118 auto rai =
119 std::make_unique<caffe2::serialize::MemoryReadAdapter>(data, size);
120 auto version = _get_model_bytecode_version_zip(std::move(rai));
121 return version;
122 }
123
124 default:
125 TORCH_CHECK(false, "Unrecognized data format");
126 }
127}
128
129uint64_t _get_model_bytecode_version(
130 const std::vector<IValue>& bytecode_ivalues) {
131 if (!bytecode_ivalues.empty() && bytecode_ivalues[0].isInt()) {
132 int64_t model_version = bytecode_ivalues[0].toInt();
133 TORCH_CHECK(
134 model_version > 0,
135 "Expected model bytecode version > 0 got ",
136 model_version);
137 return static_cast<uint64_t>(model_version);
138 }
139 TORCH_CHECK(false, "Failed to get bytecode version.");
140}
141
142/********************** Operator Version **********************/
143
144uint64_t _get_model_operator_version(
145 PyTorchStreamReader& reader); // Forward Declare
146
147uint64_t _get_model_operator_version(std::istream& in) {
148 std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
149 return _get_model_operator_version(std::move(rai));
150}
151
152uint64_t _get_model_operator_version(const std::string& filename) {
153 std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
154 return _get_model_operator_version(std::move(rai));
155}
156
157uint64_t _get_model_operator_version(
158 std::shared_ptr<ReadAdapterInterface> rai) {
159 if (!check_zip_file(rai)) {
160 TORCH_CHECK(
161 false,
162 "Failed to open .ptl file please ensure the model was exported for mobile");
163 }
164 PyTorchStreamReader reader(std::move(rai));
165 return _get_model_operator_version(reader);
166}
167
// Return the model's operator version, which the serializer records as the
// archive's version entry (exposed via PyTorchStreamReader::version()).
uint64_t _get_model_operator_version(PyTorchStreamReader& reader) {
  return reader.version();
}
171
172/********************** Operators and Info **********************/
173
174// Forward declare
175std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
176 std::vector<IValue> bytecode_ivalues);
177
178std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
179 std::istream& in) {
180 std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
181 return _get_model_ops_and_info(std::move(rai));
182}
183
184std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
185 const std::string& filename) {
186 std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
187 return _get_model_ops_and_info(std::move(rai));
188}
189
190std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
191 std::shared_ptr<ReadAdapterInterface> rai) {
192 if (!check_zip_file(rai)) {
193 TORCH_WARN("Failed to open zip file for model ops.");
194 return std::unordered_map<std::string, OperatorInfo>{};
195 }
196 PyTorchStreamReader reader(std::move(rai));
197 auto bytecode_values = get_bytecode_ivalues(reader);
198 return _get_model_ops_and_info(bytecode_values);
199}
200
/* A function to retrieve the root (top-level) operators of a model and their
 * corresponding compatibility info. These root operators can call other
 * operators within them (traced ops), and a root op can call many different
 * traced ops depending on internal code paths in the root op. These traced ops
 * are not returned by this function. Those operators are abstracted into the
 * runtime as an implementation detail (and the traced ops themselves can also
 * call other operators), making them difficult to retrieve and of negligible
 * value in this API, since they will differ depending on which runtime version
 * the model is run on. Because of this, there is a false-positive case this
 * API cannot prevent in a compatibility use case: all the root ops of a model
 * may be present in a target runtime while some of the traced ops are not,
 * which still prevents the model from being able to run.
 **/
214std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
215 std::vector<IValue> bytecode_ivalues) {
216 constexpr uint64_t min_version_with_schema = 6;
217 if (_get_model_bytecode_version(bytecode_ivalues) < min_version_with_schema) {
218 TORCH_WARN(
219 "Only models with bytecode version 6 and above contain operator schema information. Please re-export your model to generate it");
220 }
221 std::unordered_map<std::string, OperatorInfo> result;
222 if (bytecode_ivalues.empty()) {
223 TORCH_WARN("Failed to get model ops and info.");
224 return result;
225 }
226 // loop over all the functions in the bytecode
227 for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
228 // descend to the operators list
229 const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
230 auto operators_tuple = method_tuple.at(1).toTupleRef().elements()[1];
231 auto operators = operators_tuple.toTupleRef().elements()[1];
232 for (auto& op_tuple : operators.toTupleRef().elements()) {
233 const auto& op = op_tuple.toTupleRef().elements();
234
235 // grab name
236 std::string op_name = op.at(0).toStringRef();
237 std::string op_overload_name = op.at(1).toStringRef();
238 if (!op_overload_name.empty()) {
239 op_name.append(".");
240 op_name.append(op_overload_name);
241 }
242
243 // grab schema size
244 if (op.size() > 2) {
245 result.emplace(op_name, OperatorInfo{(int)op.at(2).toInt()});
246 } else { // no schema information use default
247 result.emplace(op_name, OperatorInfo{});
248 }
249 }
250 }
251 return result;
252}
253
254/********************** Get Type Table **********************/
255
256// Forward declare
257std::unordered_set<std::string> _get_mobile_model_contained_types(
258 const std::vector<IValue>& bytecode_ivalues);
259
260std::unordered_set<std::string> _get_mobile_model_contained_types(
261 std::istream& in) {
262 std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
263 return _get_mobile_model_contained_types(std::move(rai));
264}
265
266std::unordered_set<std::string> _get_mobile_model_contained_types(
267 const std::string& filename) {
268 std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
269 return _get_mobile_model_contained_types(std::move(rai));
270}
271
272std::unordered_set<std::string> _get_mobile_model_contained_types(
273 std::shared_ptr<ReadAdapterInterface> rai) {
274 if (!check_zip_file(rai)) {
275 TORCH_CHECK(
276 false,
277 "Failed to open .ptl file please ensure the model was exported for mobile");
278 }
279 PyTorchStreamReader reader(std::move(rai));
280 auto bytecode_values = get_bytecode_ivalues(reader);
281 return _get_mobile_model_contained_types(bytecode_values);
282}
283
// Get the deduplicated type table from the bytecode; each string is an atomic
// type, such as str, Tensor, etc. For example,
// input: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
// output: {Dict, int, Tuple, Tensor}
288std::unordered_set<std::string> _get_mobile_model_contained_types(
289 const std::vector<IValue>& bytecode_ivalues) {
290 std::unordered_set<std::string> contained_types;
291 // To avoid parsing same type twice, declare $parsed_type_names_records and
292 // use type name (string, ex: "Dict[int, Tuple[Tensor, Tensor, Tensor]]") as
293 // the hash to record which types are parsed.
294 std::unordered_set<std::string> parsed_type_names_records;
295 for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
296 const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
297 auto type_table_tuple =
298 method_tuple.at(1).toTupleRef().elements()[BYTECODE_INDEX_TYPE];
299 const auto& type_table =
300 type_table_tuple.toTupleRef().elements()[1].toTupleRef().elements();
301
302 // type_table is a list of IValue, and each IValue is a string,
303 // for example: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
304 std::vector<std::string> type_name_list;
305 for (const auto& type_definition : type_table) {
306 std::unordered_set<std::string> type_tokens;
307 std::string type_name = type_definition.toStringRef();
308 type_name_list.emplace_back(type_name);
309 }
310 at::TypeParser parser(type_name_list);
311 parser.parseList();
312 contained_types = parser.getContainedTypes();
313 }
314
315 return contained_types;
316}
317
318/********************** Compatibility Checker **********************/
319
320ModelCompatibilityInfo ModelCompatibilityInfo::get(std::istream& in) {
321 std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
322 return get(std::move(rai));
323}
324
325ModelCompatibilityInfo ModelCompatibilityInfo::get(
326 const std::string& filename) {
327 std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
328 return get(std::move(rai));
329}
330
331ModelCompatibilityInfo ModelCompatibilityInfo::get(
332 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai) {
333 if (!check_zip_file(rai)) {
334 TORCH_CHECK(
335 false, "Failed to open zip file for model compatibility information");
336 }
337 PyTorchStreamReader reader(std::move(rai));
338 std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader);
339 uint64_t model_bytecode_version =
340 _get_model_bytecode_version(bytecode_values);
341 auto model_info = _get_model_ops_and_info(bytecode_values);
342 std::unordered_set<std::string> type_table =
343 _get_mobile_model_contained_types(bytecode_values);
344 uint64_t operator_version = _get_model_operator_version(reader);
345 return ModelCompatibilityInfo{
346 model_bytecode_version, model_info, type_table, operator_version};
347}
348
349ModelCompatCheckResult is_compatible(
350 RuntimeCompatibilityInfo runtime_info,
351 ModelCompatibilityInfo model_info) {
352 ModelCompatCheckResult result = {ModelCompatibilityStatus::OK, {}};
353 // Check that the models bytecode version is less than or equal to
354 // kMaxSupportedBytecodeVersion from the runtime
355 if (model_info.bytecode_version >
356 runtime_info.min_max_supported_bytecode_version.second) {
357 result.status = ModelCompatibilityStatus::ERROR;
358 std::ostringstream s;
359 s << "model bytecode version " << model_info.bytecode_version
360 << "is greater than the max supported bytecode version in runtimes "
361 << runtime_info.min_max_supported_bytecode_version.second;
362 result.errors.emplace_back(s.str());
363 } else if (
364 model_info.bytecode_version <
365 runtime_info.min_max_supported_bytecode_version.first) {
366 result.status = ModelCompatibilityStatus::ERROR;
367 std::ostringstream s;
368 s << "model bytecode version " << model_info.bytecode_version
369 << "is less than the minimum supported bytecode version in runtime "
370 << runtime_info.min_max_supported_bytecode_version.first;
371 result.errors.emplace_back(s.str());
372 }
373
374 std::unordered_set<std::string> supported_type = runtime_info.supported_types;
375
376 // Check type table
377 for (const auto& type_name : model_info.type_table) {
378 if (supported_type.find(type_name) == supported_type.end()) {
379 result.status = ModelCompatibilityStatus::ERROR;
380 std::ostringstream s;
381 s << "Primitive type: '" << type_name
382 << "' is not supported in current runtime";
383 result.errors.push_back(s.str());
384 }
385 }
386
387 // Check operators
388 std::unordered_map<std::string, OperatorInfo> operator_info =
389 model_info.operator_info;
390 for (auto const& op : operator_info) {
391 std::string op_name = op.first;
392 OperatorInfo model_op_info = op.second;
393
394 // Check if operator not present in runtime
395 if (runtime_info.operator_info.find(op_name) ==
396 runtime_info.operator_info.end()) {
397 result.status = ModelCompatibilityStatus::ERROR;
398 std::ostringstream s;
399 s << "Operator '" << op_name << "' missing from runtime (not found)";
400 result.errors.push_back(s.str());
401 } else {
402 OperatorInfo runtime_op_info = runtime_info.operator_info.at(op_name);
403
404 // If the runtime op has no schema information its a false alarm and isn't
405 // actually useable
406 if (!runtime_op_info.num_schema_args.has_value()) {
407 result.status = ModelCompatibilityStatus::ERROR;
408 std::ostringstream s;
409 s << "Operator '" << op_name
410 << "' missing from runtime (missing schema)";
411 result.errors.push_back(s.str());
412 } else {
413 // Check if the model operator has schema information. If it doesn't
414 // then the model is from a bytecode version < 6 and we are done. If the
415 // model has more args than the runtime, then the runtime can't know
416 // what to do so we aren't compatible. If the runtime has more args than
417 // the model then we can just use default values and be fine.
418 if (model_op_info.num_schema_args.has_value() &&
419 (model_op_info.num_schema_args.value() >
420 runtime_op_info.num_schema_args.value())) {
421 result.status = ModelCompatibilityStatus::ERROR;
422 std::ostringstream s;
423 s << "Operator schema for'" << op_name << "' has "
424 << model_op_info.num_schema_args.value()
425 << " args in model but only "
426 << runtime_op_info.num_schema_args.value() << " in the runtime";
427 result.errors.push_back(s.str());
428 }
429 }
430 }
431 }
432
433 // Check Operator Versions
434 if (model_info.operator_version <
435 runtime_info.min_max_supported_opperator_versions.first ||
436 model_info.operator_version >
437 runtime_info.min_max_supported_opperator_versions.second) {
438 result.status = ModelCompatibilityStatus::ERROR;
439 std::ostringstream s;
440 s << "Model Operator Version " << model_info.operator_version
441 << "is not within supported version range of the runtime "
442 << runtime_info.min_max_supported_opperator_versions.first << " to "
443 << runtime_info.min_max_supported_opperator_versions.second;
444 result.errors.push_back(s.str());
445 }
446
447 return result;
448}
449
450} // namespace jit
451} // namespace torch
452