1#pragma once
2
#include <c10/macros/Export.h>
#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>

#include <cstdint>
#include <istream>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
9
10namespace caffe2 {
11namespace serialize {
12class PyTorchStreamReader;
13class ReadAdapterInterface;
14} // namespace serialize
15} // namespace caffe2
16
17namespace torch {
18namespace jit {
19
20// The family of methods below to get bytecode version from a model
21// Throws if not passed in a well formed model
22TORCH_API uint64_t _get_model_bytecode_version(std::istream& in);
23
24TORCH_API uint64_t _get_model_bytecode_version(const std::string& filename);
25
26TORCH_API uint64_t _get_model_bytecode_version(
27 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);
28
29uint64_t _get_model_bytecode_version(
30 const std::vector<c10::IValue>& bytecode_ivalues);
31
32// The family of methods below to get the operator version from a model
33// Throws if not passed in a well formed model
34TORCH_API uint64_t _get_model_operator_version(std::istream& in);
35
36TORCH_API uint64_t _get_model_operator_version(const std::string& filename);
37
38TORCH_API uint64_t _get_model_operator_version(
39 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);
40
41// Utility Functions
42std::vector<c10::IValue> get_bytecode_ivalues(
43 caffe2::serialize::PyTorchStreamReader& reader);
44
45c10::IValue readArchive(
46 const std::string& archive_name,
47 caffe2::serialize::PyTorchStreamReader& stream_reader);
48
49bool check_zip_file(
50 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);
51
52// The family of methods below to get the root ops and information from a model
53TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
54 std::istream& in);
55
56TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
57 const std::string& filename);
58
59TORCH_API std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
60 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);
61
62// The family of methods below to get contained types from a model
63// Throws if not passed in a well formed model
64TORCH_API std::unordered_set<std::string> _get_mobile_model_contained_types(
65 std::istream& in);
66
67TORCH_API std::unordered_set<std::string> _get_mobile_model_contained_types(
68 const std::string& filename);
69
70TORCH_API std::unordered_set<std::string> _get_mobile_model_contained_types(
71 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);
72
73std::unordered_set<std::string> _get_mobile_model_contained_types(
74 const std::vector<c10::IValue>& bytecode_ivalues);
75
76// The family of methods below return the compatibility information of a model
77struct ModelCompatibilityInfo {
78 uint64_t bytecode_version;
79 std::unordered_map<std::string, OperatorInfo> operator_info;
80 std::unordered_set<std::string> type_table;
81 uint64_t operator_version;
82
83 // Factory Methods
84 static TORCH_API ModelCompatibilityInfo get(std::istream& in);
85 static TORCH_API ModelCompatibilityInfo get(const std::string& filename);
86 static TORCH_API ModelCompatibilityInfo
87 get(std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai);
88};
89
// Overall verdict stored in ModelCompatCheckResult::status.
// NOTE(review): the enumerators are unscoped, so `OK` and `ERROR` are injected
// directly into the enclosing namespace. `ERROR` may also clash with a macro of
// the same name from some platform headers (e.g. <wingdi.h> on Windows). The
// names and values are kept unchanged for existing callers.
enum ModelCompatibilityStatus { OK = 1, ERROR = 2 };
94
95struct ModelCompatCheckResult {
96 ModelCompatibilityStatus status;
97 std::vector<std::string> errors;
98};
99// Takes in information about a runtime and a model and returns if the two are
100// compatible with one another.
101TORCH_API ModelCompatCheckResult is_compatible(
102 RuntimeCompatibilityInfo runtime_info,
103 ModelCompatibilityInfo model_info);
104
105} // namespace jit
106} // namespace torch
107