1 | #pragma once |
---|---|
2 | |
3 | #include <ATen/core/ivalue.h> |
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | |
8 | // Used in torch.package and TorchScript serialization to coordinate |
9 | // sharing of storages between models. Also used to create deterministic |
10 | // naming for storages. |
11 | class TORCH_API SerializationStorageContext { |
12 | public: |
13 | explicit SerializationStorageContext() = default; |
14 | SerializationStorageContext operator=(const SerializationStorageContext&) = |
15 | delete; |
16 | SerializationStorageContext(const SerializationStorageContext&) = delete; |
17 | |
18 | uint64_t getOrAddStorage(c10::Storage storage) { |
19 | if (!hasStorage(storage)) { |
20 | uint64_t size = storage_id_map_.size(); |
21 | storage_id_map_[storage] = size; |
22 | } |
23 | return storage_id_map_[storage]; |
24 | } |
25 | |
26 | bool hasStorage(c10::Storage storage) { |
27 | return storage_id_map_.find(storage) != storage_id_map_.end(); |
28 | } |
29 | |
30 | ~SerializationStorageContext() = default; |
31 | |
32 | private: |
33 | class StorageSerializationHash { |
34 | public: |
35 | size_t operator()(const c10::Storage& storage) const { |
36 | return std::hash<void*>()( |
37 | reinterpret_cast<void*>(storage.unsafeGetStorageImpl())); |
38 | } |
39 | }; |
40 | |
41 | class StorageSerializationEqual { |
42 | public: |
43 | bool operator()(const c10::Storage& lhs, const c10::Storage& rhs) const { |
44 | return lhs.unsafeGetStorageImpl() == rhs.unsafeGetStorageImpl(); |
45 | } |
46 | }; |
47 | |
48 | std::unordered_map< |
49 | c10::Storage, |
50 | uint64_t, |
51 | StorageSerializationHash, |
52 | StorageSerializationEqual> |
53 | storage_id_map_; |
54 | }; |
55 | |
56 | // Used in torch.package and TorchScript deserialization to coordinate |
57 | // sharing of storages between models. |
58 | class TORCH_API DeserializationStorageContext { |
59 | public: |
60 | explicit DeserializationStorageContext() = default; |
61 | DeserializationStorageContext operator=( |
62 | const DeserializationStorageContext&) = delete; |
63 | DeserializationStorageContext(const DeserializationStorageContext&) = delete; |
64 | |
65 | void addStorage(const std::string& name, c10::Storage storage) { |
66 | TORCH_INTERNAL_ASSERT(!hasStorage(name)); |
67 | name_storage_map_.insert({name, storage}); |
68 | } |
69 | |
70 | bool hasStorage(const std::string& name) { |
71 | return name_storage_map_.find(name) != name_storage_map_.end(); |
72 | } |
73 | |
74 | c10::Storage getStorage(const std::string& name) { |
75 | TORCH_INTERNAL_ASSERT(hasStorage(name)); |
76 | return name_storage_map_.find(name)->second; |
77 | } |
78 | ~DeserializationStorageContext() = default; |
79 | |
80 | private: |
81 | std::unordered_map<std::string, c10::Storage> name_storage_map_; |
82 | }; |
83 | |
84 | } // namespace jit |
85 | } // namespace torch |
86 |