1 | #pragma once |
---|---|
2 | |
#include <c10/macros/Export.h>
#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
// NOTE(review): c10::IValue is used below, but none of the headers listed here
// visibly declares it; it presumably arrives transitively or from includes that
// precede this header. Consider including its defining header directly
// (<ATen/core/ivalue.h> in PyTorch — confirm).

#include <cstdint>
#include <istream>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
9 | |
// Forward declarations of the caffe2 serialization types used by the
// declarations in this header. They only appear here behind a reference or a
// std::shared_ptr, so their full definitions are not needed.
namespace caffe2 {
namespace serialize {
class ReadAdapterInterface;
class PyTorchStreamReader;
} // namespace serialize
} // namespace caffe2
16 | |
17 | namespace torch { |
18 | namespace jit { |
19 | |
20 | // The family of methods below to get bytecode version from a model |
21 | // Throws if not passed in a well formed model |
22 | TORCH_API uint64_t _get_model_bytecode_version(std::istream& in); |
23 | |
24 | TORCH_API uint64_t _get_model_bytecode_version(const std::string& filename); |
25 | |
26 | TORCH_API uint64_t _get_model_bytecode_version( |
27 | std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai); |
28 | |
29 | uint64_t _get_model_bytecode_version( |
30 | const std::vector<c10::IValue>& bytecode_ivalues); |
31 | |
32 | // The family of methods below to get the operator version from a model |
33 | // Throws if not passed in a well formed model |
34 | TORCH_API uint64_t _get_model_operator_version(std::istream& in); |
35 | |
36 | TORCH_API uint64_t _get_model_operator_version(const std::string& filename); |
37 | |
38 | TORCH_API uint64_t _get_model_operator_version( |
39 | std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai); |
40 | |
41 | // Utility Functions |
42 | std::vector<c10::IValue> get_bytecode_ivalues( |
43 | caffe2::serialize::PyTorchStreamReader& reader); |
44 | |
45 | c10::IValue readArchive( |
46 | const std::string& archive_name, |
47 | caffe2::serialize::PyTorchStreamReader& stream_reader); |
48 | |
49 | bool check_zip_file( |
50 | std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai); |
51 | |
52 | // The family of methods below to get the root ops and information from a model |
53 | TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info( |
54 | std::istream& in); |
55 | |
56 | TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info( |
57 | const std::string& filename); |
58 | |
59 | TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info( |
60 | std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai); |
61 | |
62 | // The family of methods below to get contained types from a model |
63 | // Throws if not passed in a well formed model |
64 | TORCH_API std::unordered_set<std::string> _get_mobile_model_contained_types( |
65 | std::istream& in); |
66 | |
67 | TORCH_API std::unordered_set<std::string> _get_mobile_model_contained_types( |
68 | const std::string& filename); |
69 | |
70 | TORCH_API std::unordered_set<std::string> _get_mobile_model_contained_types( |
71 | std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai); |
72 | |
73 | std::unordered_set<std::string> _get_mobile_model_contained_types( |
74 | const std::vector<c10::IValue>& bytecode_ivalues); |
75 | |
76 | // The family of methods below return the compatibility information of a model |
77 | struct ModelCompatibilityInfo { |
78 | uint64_t bytecode_version; |
79 | std::unordered_map<std::string, OperatorInfo> operator_info; |
80 | std::unordered_set<std::string> type_table; |
81 | uint64_t operator_version; |
82 | |
83 | // Factory Methods |
84 | static TORCH_API ModelCompatibilityInfo get(std::istream& in); |
85 | static TORCH_API ModelCompatibilityInfo get(const std::string& filename); |
86 | static TORCH_API ModelCompatibilityInfo |
87 | get(std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai); |
88 | }; |
89 | |
// Status carried by ModelCompatCheckResult.
// The enumerators start at 1, so a zero value (e.g. from value-initializing a
// ModelCompatCheckResult) matches neither of them.
// NOTE(review): this is an unscoped enum, so OK and ERROR are injected into the
// enclosing namespace. ERROR can also collide with the ERROR macro defined by
// Windows' <wingdi.h>; prefer qualified use (ModelCompatibilityStatus::OK).
enum ModelCompatibilityStatus { OK = 1, ERROR = 2 };
94 | |
95 | struct ModelCompatCheckResult { |
96 | ModelCompatibilityStatus status; |
97 | std::vector<std::string> errors; |
98 | }; |
99 | // Takes in information about a runtime and a model and returns if the two are |
100 | // compatible with one another. |
101 | TORCH_API ModelCompatCheckResult is_compatible( |
102 | RuntimeCompatibilityInfo runtime_info, |
103 | ModelCompatibilityInfo model_info); |
104 | |
105 | } // namespace jit |
106 | } // namespace torch |
107 |